DbFactory.cs 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541
  1. using System.Reflection;
  2. using FluentResults;
  3. using InABox.Clients;
  4. using InABox.Configuration;
  5. using InABox.Core;
  6. using InABox.Mail;
  7. using InABox.Scripting;
  8. namespace InABox.Database;
  /// <summary>
  /// Global settings record holding the identity of the database. DbFactory stores it as a
  /// GlobalSettings row whose section is <c>nameof(DatabaseMetadata)</c> and whose key is empty.
  /// </summary>
  public class DatabaseMetadata : BaseObject, IGlobalConfigurationSettings
  {
      /// <summary>
      /// Unique identifier of the database. Every freshly constructed instance starts with a new
      /// random GUID; a deserialized (stored) instance carries the persisted value instead.
      /// </summary>
      public Guid DatabaseID { get; set; } = Guid.NewGuid();
  }
  /// <summary>
  /// Exception signalling that the database is locked (read-only) because the PRS license has expired.
  /// NOTE(review): not thrown in this file — presumably raised by stores when saving while
  /// DbFactory.IsReadOnly is set; confirm at the throw sites.
  /// </summary>
  public class DbLockedException : Exception
  {
      private const string LockedMessage = "Database is read-only due to PRS license expiry.";

      /// <summary>Creates the exception carrying the standard read-only message.</summary>
      public DbLockedException()
          : base(LockedMessage)
      {
      }
  }
  17. public static class DbFactory
  18. {
  /// <summary>
  /// Successfully compiled scripts, keyed by <c>"{Section} {ScriptType}"</c>. Cleared and
  /// repopulated by LoadScripts.
  /// </summary>
  public static Dictionary<string, ScriptDocument> LoadedScripts = new();

  // Identity metadata of the connected database; replaced by CheckMetadata during Start.
  private static DatabaseMetadata MetaData { get; set; } = new();

  /// <summary>
  /// Unique identifier of the database. Assigning a new value immediately persists the
  /// metadata through SaveMetadata.
  /// </summary>
  public static Guid ID
  {
      get
      {
          return MetaData.DatabaseID;
      }
      set
      {
          MetaData.DatabaseID = value;
          SaveMetadata();
      }
  }

  private static IProviderFactory? _providerFactory;

  /// <summary>
  /// Factory used to create database providers.
  /// </summary>
  /// <exception cref="Exception">Thrown by the getter when no factory has been assigned yet.</exception>
  public static IProviderFactory ProviderFactory
  {
      get
      {
          if (_providerFactory is null)
              throw new Exception("Provider is not set");
          return _providerFactory;
      }
      set => _providerFactory = value;
  }

  /// <summary>True once a <see cref="ProviderFactory"/> has been assigned.</summary>
  public static bool IsProviderSet => _providerFactory is not null;

  public static string? ColorScheme { get; set; }

  public static byte[]? Logo { get; set; }

  /// <summary>
  /// Open generic store type that FindStore closes over the requested entity type whenever no
  /// registered store derives from it.
  /// </summary>
  public static Type DefaultStore { get; set; } = typeof(Store<>);

  // Transitional settings: see the notes in the Request.DatabaseInfo class. These ports are to be
  // removed once the RPC transport is stable.
  public static int RestPort { get; set; }

  public static int RPCPort { get; set; }

  /// <summary>
  /// Every <see cref="IPersistent"/> entity type listed in <see cref="CoreUtils.Entities"/>.
  /// </summary>
  public static IEnumerable<Type> Entities =>
      from entity in CoreUtils.Entities
      where entity.HasInterface<IPersistent>()
      select entity;

  /// <summary>
  /// Registered store types. Assigning filters out abstract/generic/non-class types and rejects
  /// any type that does not implement IStore (see SetStoreTypes).
  /// </summary>
  public static Type[] Stores
  {
      get { return stores; }
      set { SetStoreTypes(value); }
  }

  public static string? EmailAddress { get; set; }

  public static ICoreMailer? Mailer { get; set; }

  public static DateTime Expiry { get; set; }

  /// <summary>
  /// Creates a new provider from <see cref="ProviderFactory"/>, passing <paramref name="logger"/> through.
  /// </summary>
  public static IProvider NewProvider(Logger logger)
  {
      return ProviderFactory.NewProvider(logger);
  }
  /// <summary>
  /// Boots the database layer.
  /// The sequence is:
  /// <list type="number">
  /// <item>check licensing and register the entity types with the provider factory;</item>
  /// <item>start the provider and load (or create) the metadata;</item>
  /// <item>migrate the database and validate the license;</item>
  /// <item>start the periodic license check, initialise the stores, and load the script cache.</item>
  /// </list>
  /// </summary>
  /// <param name="types">
  /// Entity types to register. When supplied, the non-abstract types of the assembly containing
  /// CoreUtils are appended. When null, <see cref="Entities"/> is used instead. In both cases only
  /// non-generic classes deriving from Entity are kept.
  /// </param>
  /// <exception cref="Exception">Thrown when DataUpdater.MigrateDatabase() reports failure.</exception>
  public static void Start(Type[]? types = null)
  {
      static bool IsConcreteEntity(Type type) =>
          type.IsClass && !type.IsGenericType && type.IsSubclassOf(typeof(Entity));

      CoreUtils.CheckLicensing();

      ProviderFactory.Types = (types is null
              ? Entities
              : types.Concat(CoreUtils.IterateTypes(typeof(CoreUtils).Assembly).Where(t => !t.IsAbstract)))
          .Where(IsConcreteEntity)
          .ToArray();

      // Bring the provider online before anything queries it.
      ProviderFactory.Start();
      CheckMetadata();

      var migrated = DataUpdater.MigrateDatabase();
      if (!migrated)
          throw new Exception("Database migration failed. Aborting startup");

      AssertLicense();
      BeginLicenseCheckTimer();
      InitStores();
      LoadScripts();
  }
  /// <summary>
  /// Reads the <see cref="DatabaseVersion"/> global setting. When no such row exists, a row with
  /// version "6.30b" is created, saved, and returned.
  /// </summary>
  public static DatabaseVersion GetVersionSettings()
  {
      var provider = ProviderFactory.NewProvider(Logger.New());
      var stored = provider
          .Query(new Filter<GlobalSettings>(x => x.Section).IsEqualTo(nameof(DatabaseVersion)))
          .Rows
          .FirstOrDefault()
          ?.ToObject<GlobalSettings>();

      if (stored != null)
          return Serialization.Deserialize<DatabaseVersion>(stored.Contents);

      // No version row yet: seed one.
      // NOTE(review): "6.30b" presumably marks databases predating version tracking — confirm.
      var row = new GlobalSettings() { Section = nameof(DatabaseVersion), Key = "" };
      var initial = new DatabaseVersion() { Version = "6.30b" };
      row.Contents = Serialization.Serialize(initial);
      NewProvider(Logger.Main).Save(row);
      return initial;
  }

  /// <summary>Parses the stored database version string (see <see cref="GetVersionSettings"/>).</summary>
  public static VersionNumber GetDatabaseVersion() =>
      VersionNumber.Parse(GetVersionSettings().Version);

  private static VersionNumber? _versionNumber;

  /// <summary>
  /// Database version, read on first access and cached for the lifetime of the process.
  /// </summary>
  public static VersionNumber VersionNumber
  {
      get => _versionNumber ??= GetDatabaseVersion();
  }
  #region MetaData

  /// <summary>
  /// Persists the current metadata as the <c>DatabaseMetadata</c> global setting (empty key).
  /// </summary>
  private static void SaveMetadata()
  {
      var row = new GlobalSettings
      {
          Section = nameof(DatabaseMetadata),
          Key = "",
          Contents = Serialization.Serialize(MetaData)
      };
      var provider = ProviderFactory.NewProvider(Logger.Main);
      provider.Save(row);
  }

  /// <summary>
  /// Loads the stored metadata into <c>MetaData</c>. If no row exists, or its contents do not
  /// deserialize, fresh metadata (with a new database ID) is created and saved.
  /// </summary>
  private static void CheckMetadata()
  {
      var stored = ProviderFactory.NewProvider(Logger.Main)
          .Query(new Filter<GlobalSettings>(x => x.Section).IsEqualTo(nameof(DatabaseMetadata)))
          .Rows
          .FirstOrDefault()
          ?.ToObject<GlobalSettings>();

      DatabaseMetadata? loaded = null;
      if (stored is not null)
          loaded = Serialization.Deserialize<DatabaseMetadata>(stored.Contents);

      if (loaded is not null)
      {
          MetaData = loaded;
          return;
      }

      MetaData = new DatabaseMetadata();
      SaveMetadata();
  }

  #endregion
  #region License

  /// <summary>Outcome of CheckLicenseValidity.</summary>
  private enum LicenseValidation
  {
      /// <summary>All checks passed and the expiry is not in the past.</summary>
      Valid,
      /// <summary>No License row exists.</summary>
      Missing,
      /// <summary>The license's expiry is earlier than now.</summary>
      Expired,
      /// <summary>LicenseUtils.TryDecryptLicense failed.</summary>
      Corrupt,
      /// <summary>MAC-address validation failed, or a UserTracking ID recorded in the license no longer exists.</summary>
      Tampered
    }

  /// <summary>
  /// Validates the first stored License row.
  /// The checks, in order:
  /// <list type="number">
  /// <item>a row exists;</item>
  /// <item>it decrypts;</item>
  /// <item>its MAC addresses validate;</item>
  /// <item>every UserTracking ID it lists still exists;</item>
  /// <item>it has not expired.</item>
  /// </list>
  /// </summary>
  /// <param name="expiry">
  /// The license expiry once the license has passed every check before the expiry comparison;
  /// otherwise <see cref="DateTime.MinValue"/>.
  /// </param>
  /// <returns>The first failing check, or <see cref="LicenseValidation.Valid"/>.</returns>
  private static LicenseValidation CheckLicenseValidity(out DateTime expiry)
  {
      var provider = ProviderFactory.NewProvider(Logger.New());
      expiry = DateTime.MinValue;

      var license = provider.Load<License>().FirstOrDefault();
      if (license is null)
          return LicenseValidation.Missing;

      if (!LicenseUtils.TryDecryptLicense(license.Data, out var licenseData, out _))
          return LicenseValidation.Corrupt;

      if (!LicenseUtils.ValidateMacAddresses(licenseData.Addresses))
          return LicenseValidation.Tampered;

      // Every UserTracking ID recorded in the license must still exist in the database.
      // A HashSet gives O(1) membership tests; the previous array scan was O(n) per item (O(n^2) overall).
      var existingTrackingIds = provider.Query(
              new Filter<UserTracking>(x => x.ID).InList(licenseData.UserTrackingItems),
              Columns.None<UserTracking>().Add(x => x.ID),
              log: false)
          .Rows
          .Select(r => r.Get<UserTracking, Guid>(c => c.ID))
          .ToHashSet();

      foreach (var item in licenseData.UserTrackingItems)
      {
          if (!existingTrackingIds.Contains(item))
              return LicenseValidation.Tampered;
      }

      expiry = licenseData.Expiry;
      return licenseData.Expiry < DateTime.Now
          ? LicenseValidation.Expired
          : LicenseValidation.Valid;
  }
  // Number of license checks since a throttled expiry warning was last logged (see LogLicenseExpiry).
  private static int _expiredLicenseCounter;

  // Interval between background license checks.
  private static TimeSpan LicenseCheckInterval = TimeSpan.FromMinutes(10);

  private static bool _readOnly;

  /// <summary>
  /// True while the database is locked for saving because the last license check failed.
  /// </summary>
  public static bool IsReadOnly => _readOnly;

  // Declared after LicenseCheckInterval on purpose: static field initializers run in textual order.
  private static System.Timers.Timer LicenseTimer = new(LicenseCheckInterval.TotalMilliseconds)
  {
      AutoReset = true
  };
  /// <summary>
  /// Logs <paramref name="message"/> as an important message, followed by license-renewal instructions.
  /// </summary>
  private static void LogRenew(string message)
  {
      LogImportant($"{message} Please renew your license before then, or your database will go into read-only mode; it will be locked for saving anything until you renew your license. For help with renewing your license, please see the documentation at https://prsdigital.com.au/wiki/index.php/License_Renewal.");
  }

  /// <summary>Logs an important message stating that the database is in read-only mode.</summary>
  public static void LogReadOnly()
  {
      LogImportant("Your database is in read-only mode; please renew your license to enable database updates.");
  }

  /// <summary>
  /// Warns about an approaching license expiry.
  /// The schedule is:
  /// <list type="bullet">
  /// <item>expiring today, or within a day: on every check;</item>
  /// <item>within three days: at most once per hour of elapsed check time;</item>
  /// <item>within a week: at most once per two hours of elapsed check time.</item>
  /// </list>
  /// </summary>
  /// <param name="expiry">The license expiry (in the past only if the caller passes one).</param>
  private static void LogLicenseExpiry(DateTime expiry)
  {
      if (expiry.Date == DateTime.Today)
      {
          // Format the DateTime, not expiry.TimeOfDay: that is a TimeSpan, for which "HH:mm" is an
          // invalid custom format (no 'H' specifier, unescaped ':') and throws FormatException.
          LogRenew($"Your database license is expiring today at {expiry:HH:mm}!");
          return;
      }

      var diffInDays = (expiry - DateTime.Now).TotalDays;
      if (diffInDays < 1)
      {
          LogRenew($"Your database license will expire in less than a day, on the {expiry:dd MMM yyyy} at {expiry:hh:mm tt}.");
      }
      else if (diffInDays < 3 && (_expiredLicenseCounter * LicenseCheckInterval).TotalHours >= 1)
      {
          LogRenew($"Your database license will expire in less than three days, on the {expiry:dd MMM yyyy} at {expiry:hh:mm tt}.");
          _expiredLicenseCounter = 0;
      }
      else if (diffInDays < 7 && (_expiredLicenseCounter * LicenseCheckInterval).TotalHours >= 2)
      {
          LogRenew($"Your database license will expire in less than a week, on the {expiry:dd MMM yyyy} at {expiry:hh:mm tt}.");
          _expiredLicenseCounter = 0;
      }

      // Counts checks since the last throttled warning; multiplied by the check interval above.
      ++_expiredLicenseCounter;
  }
  /// <summary>
  /// Switches the database into read-only mode, logging the change only on the transition.
  /// </summary>
  private static void BeginReadOnly()
  {
      if (_readOnly)
          return;

      LogImportant(
          "Your database is now in read-only mode, since your license is invalid; you will be unable to save any records to the database until you renew your license. For help with renewing your license, please see the documentation at https://prsdigital.com.au/wiki/index.php/License_Renewal.");
      _readOnly = true;
  }

  /// <summary>
  /// Leaves read-only mode, logging the change only on the transition.
  /// </summary>
  private static void EndReadOnly()
  {
      if (!_readOnly)
          return;

      LogImportant("Valid license found; the database is no longer read-only.");
      _readOnly = false;
  }
  /// <summary>
  /// Starts the periodic background license check driven by <c>LicenseTimer</c>.
  /// Safe to call more than once.
  /// </summary>
  private static void BeginLicenseCheckTimer()
  {
      // Remove any existing subscription first. Otherwise each extra call (e.g. Start invoked twice)
      // attaches another handler and the full license check runs multiple times per tick.
      LicenseTimer.Elapsed -= LicenseTimer_Elapsed;
      LicenseTimer.Elapsed += LicenseTimer_Elapsed;
      LicenseTimer.Start();
  }

  /// <summary>Timer callback: re-validates the license.</summary>
  private static void LicenseTimer_Elapsed(object? sender, System.Timers.ElapsedEventArgs e)
  {
      try
      {
          AssertLicense();
      }
      catch (Exception ex)
      {
          // System.Timers.Timer suppresses exceptions thrown by Elapsed handlers, so without this
          // a failing check (e.g. the database being unreachable) would disappear silently.
          LogError($"License check failed: {ex}");
      }
  }
  /// <summary>
  /// Validates the stored license and updates read-only mode to match.
  /// A valid license logs any approaching-expiry warning and ends read-only mode. Any other outcome
  /// logs the reason and begins read-only mode.
  /// The check is skipped entirely when CoreUtils.GetVersion() returns "???"
  /// (NOTE(review): presumably an unversioned/development build — confirm).
  /// </summary>
  public static void AssertLicense()
  {
      if (string.Equals(CoreUtils.GetVersion(), "???"))
          return;

      var validity = CheckLicenseValidity(out var expiry);
      if (validity == LicenseValidation.Valid)
      {
          LogLicenseExpiry(expiry);
          EndReadOnly();
          return;
      }

      string? problem = validity switch
      {
          LicenseValidation.Missing => "Database is unlicensed!",
          LicenseValidation.Expired => "Database license has expired!",
          LicenseValidation.Corrupt => "Database license is corrupt - you will need to renew your license.",
          LicenseValidation.Tampered => "Database license has been tampered with - you will need to renew your license.",
          _ => null
      };

      if (problem is null)
          return;

      LogImportant(problem);
      BeginReadOnly();
  }

  #endregion
  #region Logging

  /// <summary>Sends <paramref name="message"/> to the logger as an information entry.</summary>
  private static void LogInfo(string message) => Logger.Send(LogType.Information, "", message);

  /// <summary>Sends <paramref name="message"/> to the logger as an important entry.</summary>
  private static void LogImportant(string message) => Logger.Send(LogType.Important, "", message);

  /// <summary>Sends <paramref name="message"/> to the logger as an error entry.</summary>
  private static void LogError(string message) => Logger.Send(LogType.Error, "", message);

  #endregion
  /// <summary>
  /// Instantiates each registered store type once, gives it a new provider and Logger.Main, and
  /// calls its Init method. The instances are not retained by this class.
  /// </summary>
  public static void InitStores()
  {
      var storeTypes = stores;
      for (var index = 0; index < storeTypes.Length; index++)
      {
          var instance = (Activator.CreateInstance(storeTypes[index]) as IStore)!;
          instance.Provider = ProviderFactory.NewProvider(Logger.Main);
          instance.Logger = Logger.Main;
          instance.Init();
      }
  }
  /// <summary>
  /// Creates a store for entity <paramref name="type"/>.
  /// It uses the first registered store type deriving from <c>DefaultStore&lt;type&gt;</c>, or
  /// <c>DefaultStore&lt;type&gt;</c> itself when none does. The new store is configured with a fresh
  /// provider and the caller's identity.
  /// </summary>
  public static IStore FindStore(Type type, Guid userguid, string userid, Platform platform, string version, Logger logger)
  {
      var defaultStoreType = DefaultStore.MakeGenericType(type);
      var storeType = Stores.FirstOrDefault(candidate => candidate.IsSubclassOf(defaultStoreType)) ?? defaultStoreType;

      var created = (Activator.CreateInstance(storeType) as IStore)!;
      created.Provider = ProviderFactory.NewProvider(logger);
      created.UserGuid = userguid;
      created.UserID = userid;
      created.Platform = platform;
      created.Version = version;
      created.Logger = logger;
      return created;
  }

  /// <summary>Strongly typed wrapper around the non-generic FindStore overload.</summary>
  public static IStore<TEntity> FindStore<TEntity>(Guid userguid, string userid, Platform platform, string version, Logger logger)
      where TEntity : Entity, new()
      => (FindStore(typeof(TEntity), userguid, userid, platform, version, logger) as IStore<TEntity>)!;
  /// <summary>
  /// Runs one query definition against the store for <typeparamref name="TEntity"/>.
  /// Invoked through reflection by QueryMultiple, which only knows each entity type at runtime.
  /// </summary>
  private static CoreTable DoQueryMultipleQuery<TEntity>(
      IQueryDef query,
      Guid userguid, string userid, Platform platform, string version, Logger logger)
      where TEntity : Entity, new()
  {
      var store = FindStore<TEntity>(userguid, userid, platform, version, logger);
      return store.Query(query.Filter as Filter<TEntity>, query.Columns as Columns<TEntity>, query.SortOrder as SortOrder<TEntity>);
  }

  /// <summary>
  /// Executes every query in <paramref name="queries"/> in parallel and returns the results under
  /// the same keys. Blocks until all queries complete; a failing query surfaces as an
  /// AggregateException from Task.WaitAll.
  /// </summary>
  public static Dictionary<string, CoreTable> QueryMultiple(
      Dictionary<string, IQueryDef> queries,
      Guid userguid, string userid, Platform platform, string version, Logger logger)
  {
      var queryMethod = typeof(DbFactory).GetMethod(nameof(DoQueryMultipleQuery), BindingFlags.NonPublic | BindingFlags.Static)!;

      // Each task returns its table instead of writing into a shared Dictionary:
      // Dictionary<TKey,TValue> is not safe for concurrent writes, and the old code mutated it
      // from several tasks at once. The static method is invoked with a null target.
      var pending = queries
          .Select(item => (
              Key: item.Key,
              Task: Task.Run(() => (queryMethod.MakeGenericMethod(item.Value.Type).Invoke(null, new object[]
              {
                  item.Value,
                  userguid, userid, platform, version, logger
              }) as CoreTable)!)))
          .ToArray();

      Task.WaitAll(pending.Select(p => (Task)p.Task).ToArray());

      // All tasks have completed, so reading Result does not block.
      return pending.ToDictionary(p => p.Key, p => p.Task.Result);
  }
  #region Supported Types

  // Local configuration mapping entity names (Type.EntityName()) to an enabled flag. It is stored
  // next to the database, named after the database file.
  private class ModuleConfiguration : Dictionary<string, bool>, ILocalConfigurationSettings
  {
  }

  private static Type[]? _dbtypes;

  /// <summary>
  /// Names of the enabled entity types, with '.' replaced by '_'.
  /// The list is loaded on first use and then cached.
  /// </summary>
  public static IEnumerable<string> SupportedTypes()
  {
      var supported = _dbtypes ??= LoadSupportedTypes();
      return supported.Select(type => type.EntityName().Replace(".", "_"));
  }

  /// <summary>
  /// Determines which persistent entity types are enabled for this database.
  /// Entities missing from the module configuration are enabled and written back to it;
  /// explicitly disabled entities are logged and excluded.
  /// </summary>
  private static Type[] LoadSupportedTypes()
  {
      // NOTE(review): the URL is lower-cased before splitting, so on a case-sensitive file system the
      // configuration may live in a differently-cased directory than the database — confirm intended.
      var path = ProviderFactory.URL.ToLower();
      LocalConfiguration<ModuleConfiguration> OpenConfig() =>
          new LocalConfiguration<ModuleConfiguration>(Path.GetDirectoryName(path) ?? "", Path.GetFileName(path));

      var config = OpenConfig().Load();
      var enabled = new List<Type>();
      var addedKeys = false;

      foreach (var entity in Entities)
      {
          var key = entity.EntityName();
          if (!config.TryGetValue(key, out bool isEnabled))
          {
              // Unknown entity: enable it by default and record it in the configuration.
              config[key] = true;
              enabled.Add(entity);
              addedKeys = true;
          }
          else if (isEnabled)
          {
              enabled.Add(entity);
          }
          else
          {
              Logger.Send(LogType.Information, "", $"Entity [{key}] is disabled");
          }
      }

      if (addedKeys)
          OpenConfig().Save(config);

      return enabled.ToArray();
  }

  /// <summary>True when <typeparamref name="T"/> is among the enabled entity types.</summary>
  public static bool IsSupported<T>() where T : Entity
  {
      var supported = _dbtypes ??= LoadSupportedTypes();
      return Array.IndexOf(supported, typeof(T)) >= 0;
  }

  #endregion
  387. //public static void OpenSession(bool write)
  388. //{
  389. // Provider.OpenSession(write);
  390. //}
  391. //public static void CloseSession()
  392. //{
  393. // Provider.CloseSession();
  394. //}
  #region Private Methods

  /// <summary>
  /// Rebuilds the script cache from the stored Script rows.
  /// Only the query/save/delete/load hook types are loaded. Each script is compiled; successes are
  /// cached under <c>"{Section} {ScriptType}"</c>, and failures are logged as errors and skipped.
  /// </summary>
  public static void LoadScripts()
  {
      Logger.Send(LogType.Information, "", "Loading Script Cache...");
      LoadedScripts.Clear();

      var scripts = ProviderFactory.NewProvider(Logger.Main).Load(
          new Filter<Script>
              (x => x.ScriptType).IsEqualTo(ScriptType.BeforeQuery)
              .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterQuery)
              .Or(x => x.ScriptType).IsEqualTo(ScriptType.BeforeSave)
              .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterSave)
              .Or(x => x.ScriptType).IsEqualTo(ScriptType.BeforeDelete)
              .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterDelete)
              .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterLoad)
      );

      foreach (var script in scripts)
      {
          var document = new ScriptDocument(script.Code);
          if (!document.Compile())
          {
              Logger.Send(LogType.Error, "",
                  $"- {script.Section}.{script.ScriptType} Compile Exception:\n{document.Result}");
              continue;
          }

          Logger.Send(LogType.Information, "",
              $"- {script.Section}.{script.ScriptType} Compiled Successfully");
          LoadedScripts[$"{script.Section} {script.ScriptType}"] = document;
      }

      Logger.Send(LogType.Information, "", "Loading Script Cache Complete");
  }
  428. //private static Type[] entities = null;
  429. //private static void SetEntityTypes(Type[] types)
  430. //{
  431. // foreach (Type type in types)
  432. // {
  433. // if (!type.IsSubclassOf(typeof(Entity)))
  434. // throw new Exception(String.Format("{0} is not a valid entity", type.Name));
  435. // }
  436. // entities = types;
  437. //}
  private static Type[] stores = Array.Empty<Type>();

  /// <summary>
  /// Replaces the registered store types.
  /// Abstract, generic, and non-class types are silently dropped. Any remaining type that does not
  /// implement IStore is rejected.
  /// </summary>
  /// <exception cref="Exception">Thrown for the first remaining type that does not implement IStore.</exception>
  private static void SetStoreTypes(Type[] types)
  {
      var concrete = types
          .Where(candidate => candidate.IsClass && !candidate.IsAbstract && !candidate.IsGenericType)
          .ToArray();

      var invalid = concrete.FirstOrDefault(candidate => !candidate.GetInterfaces().Contains(typeof(IStore)));
      if (invalid is not null)
          throw new Exception($"{invalid.Name} is not a valid store");

      stores = concrete;
  }

  #endregion
  451. }