// DbFactory.cs (18 KB)

  1. using System.Reflection;
  2. using FluentResults;
  3. using InABox.Clients;
  4. using InABox.Configuration;
  5. using InABox.Core;
  6. using InABox.Scripting;
  7. namespace InABox.Database;
  /// <summary>
  /// Global configuration settings object carrying the identity of this database.
  /// </summary>
  public class DatabaseMetadata : BaseObject, IGlobalConfigurationSettings
  {
      /// <summary>
      /// Identifier of the database. Each newly constructed instance receives a fresh
      /// <see cref="Guid"/> until a stored value is assigned.
      /// </summary>
      public Guid DatabaseID { get; set; } = Guid.NewGuid();
  }
  /// <summary>
  /// Exception signalling that the database is read-only because of the PRS license state.
  /// NOTE(review): not thrown in this file; confirm the intended call sites.
  /// </summary>
  public class DbLockedException : Exception
  {
      /// <summary>Create the exception with the default license-expiry message.</summary>
      public DbLockedException() : base("Database is read-only due to PRS license expiry.") { }

      /// <summary>Create the exception with a custom message.</summary>
      /// <param name="message">Message describing why the database is locked.</param>
      public DbLockedException(string? message) : base(message) { }

      /// <summary>Create the exception with a custom message and the exception that caused it.</summary>
      /// <param name="message">Message describing why the database is locked.</param>
      /// <param name="innerException">The underlying cause.</param>
      public DbLockedException(string? message, Exception? innerException) : base(message, innerException) { }
  }
  /// <summary>
  /// Static entry point for the server-side database layer. Holds the configured
  /// <see cref="IProviderFactory"/>, performs startup (entity registration, metadata,
  /// migration, license enforcement, store initialisation and script compilation) and
  /// creates <see cref="IStore"/> instances for entity types.
  /// </summary>
  public static class DbFactory
  {
      /// <summary>
      /// Successfully compiled scripts keyed by "{Section} {ScriptType}"; rebuilt by <see cref="LoadScripts"/>.
      /// </summary>
      public static Dictionary<string, ScriptDocument> LoadedScripts = new();

      private static DatabaseMetadata MetaData { get; set; } = new();

      /// <summary>
      /// Identifier of this database, held in the stored <see cref="DatabaseMetadata"/>.
      /// Assigning a value immediately persists the metadata.
      /// </summary>
      public static Guid ID
      {
          get => MetaData.DatabaseID;
          set
          {
              MetaData.DatabaseID = value;
              SaveMetadata();
          }
      }

      private static IProviderFactory? _providerFactory;

      /// <summary>The provider factory used for all database access.</summary>
      /// <exception cref="InvalidOperationException">Read before a factory has been assigned.</exception>
      public static IProviderFactory ProviderFactory
      {
          get => _providerFactory ?? throw new InvalidOperationException("Provider is not set");
          set => _providerFactory = value;
      }

      /// <summary>Whether <see cref="ProviderFactory"/> has been assigned.</summary>
      public static bool IsProviderSet => _providerFactory is not null;

      public static string? ColorScheme { get; set; }

      public static byte[]? Logo { get; set; }

      /// <summary>
      /// Open generic store type closed over the entity type by FindStore when no registered
      /// store subclasses <c>Store&lt;TEntity&gt;</c> for that entity.
      /// </summary>
      public static Type DefaultStore { get; set; } = typeof(Store<>);

      // See notes in Request.DatabaseInfo class
      // Once RPC transport is stable, these settings need
      // to be removed
      public static int RestPort { get; set; }
      public static int RPCPort { get; set; }

      /// <summary>
      /// Return every <see cref="IPersistent"/> entity in <see cref="CoreUtils.Entities"/>.
      /// </summary>
      public static IEnumerable<Type> Entities => CoreUtils.Entities.Where(x => x.HasInterface<IPersistent>());

      /// <summary>
      /// Registered store types. The setter keeps only concrete, non-generic classes and
      /// throws if any of them does not implement <see cref="IStore"/>.
      /// </summary>
      public static Type[] Stores
      {
          get => stores;
          set => SetStoreTypes(value);
      }

      public static DateTime Expiry { get; set; }

      /// <summary>Create a new provider from <see cref="ProviderFactory"/> using <paramref name="logger"/>.</summary>
      public static IProvider NewProvider(Logger logger) => ProviderFactory.NewProvider(logger);

      /// <summary>
      /// Start the database: register entity types with the provider factory and start it,
      /// load (or create) the database metadata, run migrations, enforce the license and begin
      /// the periodic license check, initialise the registered stores and compile scripts.
      /// </summary>
      /// <param name="types">
      /// Entity types to register together with the concrete types of the core assembly; when
      /// <c>null</c>, every persistent type in <see cref="Entities"/> is used. In both cases only
      /// non-generic classes deriving from <see cref="Entity"/> are registered.
      /// </param>
      /// <exception cref="InvalidOperationException">The database migration failed.</exception>
      public static void Start(Type[]? types = null)
      {
          CoreUtils.CheckLicensing();

          IEnumerable<Type> candidates = types is not null
              ? types.Concat(CoreUtils.IterateTypes(typeof(CoreUtils).Assembly).Where(x => !x.IsAbstract))
              : Entities;
          ProviderFactory.Types = candidates
              .Where(x => x.IsClass && !x.IsGenericType && x.IsSubclassOf(typeof(Entity)))
              .ToArray();

          // Start the provider
          ProviderFactory.Start();
          CheckMetadata();

          if (!DataUpdater.MigrateDatabase())
          {
              throw new InvalidOperationException("Database migration failed. Aborting startup");
          }

          AssertLicense();
          BeginLicenseCheckTimer();
          InitStores();
          LoadScripts();
      }

      /// <summary>
      /// Load the <see cref="DatabaseVersion"/> stored in global settings. When none is stored,
      /// a version of "6.30b" is saved and returned.
      /// </summary>
      public static DatabaseVersion GetVersionSettings()
      {
          var provider = ProviderFactory.NewProvider(Logger.New());
          var stored = provider.Query(new Filter<GlobalSettings>(x => x.Section).IsEqualTo(nameof(DatabaseVersion)))
              .Rows.FirstOrDefault()?.ToObject<GlobalSettings>();
          if (stored is not null)
          {
              return Serialization.Deserialize<DatabaseVersion>(stored.Contents);
          }

          // Nothing stored yet: persist and return the default version.
          var dbVersion = new DatabaseVersion { Version = "6.30b" };
          var settings = new GlobalSettings
          {
              Section = nameof(DatabaseVersion),
              Key = "",
              Contents = Serialization.Serialize(dbVersion)
          };
          NewProvider(Logger.Main).Save(settings);
          return dbVersion;
      }

      /// <summary>Parse the version returned by <see cref="GetVersionSettings"/>.</summary>
      public static VersionNumber GetDatabaseVersion() => VersionNumber.Parse(GetVersionSettings().Version);

      private static VersionNumber? _versionNumber;

      /// <summary>Database version; read via <see cref="GetDatabaseVersion"/> on first access and cached afterwards.</summary>
      public static VersionNumber VersionNumber => _versionNumber ??= GetDatabaseVersion();

      #region MetaData

      /// <summary>Persist <see cref="MetaData"/> as the <see cref="DatabaseMetadata"/> global setting.</summary>
      private static void SaveMetadata()
      {
          var settings = new GlobalSettings
          {
              Section = nameof(DatabaseMetadata),
              Key = "",
              Contents = Serialization.Serialize(MetaData)
          };
          ProviderFactory.NewProvider(Logger.Main).Save(settings);
      }

      /// <summary>
      /// Load the stored <see cref="DatabaseMetadata"/>; if none is stored (or it deserializes to
      /// null), create a new one and save it.
      /// </summary>
      private static void CheckMetadata()
      {
          var stored = ProviderFactory.NewProvider(Logger.Main)
              .Query(new Filter<GlobalSettings>(x => x.Section).IsEqualTo(nameof(DatabaseMetadata)))
              .Rows.FirstOrDefault()?.ToObject<GlobalSettings>();
          var data = stored is not null ? Serialization.Deserialize<DatabaseMetadata>(stored.Contents) : null;
          if (data is null)
          {
              MetaData = new DatabaseMetadata();
              SaveMetadata();
          }
          else
          {
              MetaData = data;
          }
      }

      #endregion

      #region License

      private enum LicenseValidation
      {
          Valid,
          Missing,
          Expired,
          Corrupt,
          Tampered
      }

      /// <summary>
      /// Validate the first stored <see cref="License"/>.
      /// </summary>
      /// <param name="expiry">
      /// The license expiry once the license has been decrypted and passed the MAC-address and
      /// user-tracking checks; otherwise <see cref="DateTime.MinValue"/>.
      /// </param>
      private static LicenseValidation CheckLicenseValidity(out DateTime expiry)
      {
          var provider = ProviderFactory.NewProvider(Logger.New());
          expiry = DateTime.MinValue;

          var license = provider.Load<License>().FirstOrDefault();
          if (license is null)
          {
              return LicenseValidation.Missing;
          }

          if (!LicenseUtils.TryDecryptLicense(license.Data, out var licenseData, out _))
          {
              return LicenseValidation.Corrupt;
          }

          if (!LicenseUtils.ValidateMacAddresses(licenseData.Addresses))
          {
              return LicenseValidation.Tampered;
          }

          // Every UserTracking record referenced by the license must still exist in the database.
          var existingTrackingIDs = provider.Query(
                  new Filter<UserTracking>(x => x.ID).InList(licenseData.UserTrackingItems),
                  Columns.None<UserTracking>().Add(x => x.ID),
                  log: false)
              .Rows
              .Select(r => r.Get<UserTracking, Guid>(c => c.ID))
              .ToHashSet();
          foreach (var item in licenseData.UserTrackingItems)
          {
              if (!existingTrackingIDs.Contains(item))
              {
                  return LicenseValidation.Tampered;
              }
          }

          expiry = licenseData.Expiry;
          return licenseData.Expiry < DateTime.Now ? LicenseValidation.Expired : LicenseValidation.Valid;
      }

      // Number of valid-license checks since the last "less than three days / a week" warning;
      // used to throttle those warnings.
      private static int _expiredLicenseCounter = 0;

      // Must stay declared before LicenseTimer: static initialisers run in textual order.
      private static readonly TimeSpan LicenseCheckInterval = TimeSpan.FromMinutes(10);

      // Written by AssertLicense, which is also invoked from the System.Timers.Timer callback.
      private static volatile bool _readOnly;

      /// <summary>
      /// True while the most recent license check found the license missing, expired, corrupt or tampered.
      /// </summary>
      public static bool IsReadOnly => _readOnly;

      private static readonly System.Timers.Timer LicenseTimer = new(LicenseCheckInterval.TotalMilliseconds) { AutoReset = true };

      private static void LogRenew(string message)
      {
          LogImportant($"{message} Please renew your license before then, or your database will go into read-only mode; it will be locked for saving anything until you renew your license. For help with renewing your license, please see the documentation at https://prsdigital.com.au/wiki/index.php/License_Renewal.");
      }

      /// <summary>Log that the database is in read-only mode.</summary>
      public static void LogReadOnly()
      {
          LogImportant("Your database is in read-only mode; please renew your license to enable database updates.");
      }

      /// <summary>
      /// Warn about an approaching license expiry: on every check when it expires today or within
      /// a day, at most hourly when within three days, and at most every two hours within a week.
      /// </summary>
      private static void LogLicenseExpiry(DateTime expiry)
      {
          if (expiry.Date == DateTime.Today)
          {
              // Format the DateTime itself: TimeSpan has no "HH" specifier, so formatting
              // expiry.TimeOfDay with "HH:mm" threw a FormatException.
              LogRenew($"Your database license is expiring today at {expiry:HH:mm}!");
              return;
          }

          var diffInDays = (expiry - DateTime.Now).TotalDays;
          var sinceLastWarning = _expiredLicenseCounter * LicenseCheckInterval;
          if (diffInDays < 1)
          {
              LogRenew($"Your database license will expire in less than a day, on the {expiry:dd MMM yyyy} at {expiry:hh:mm tt}.");
          }
          else if (diffInDays < 3 && sinceLastWarning.TotalHours >= 1)
          {
              LogRenew($"Your database license will expire in less than three days, on the {expiry:dd MMM yyyy} at {expiry:hh:mm tt}.");
              _expiredLicenseCounter = 0;
          }
          else if (diffInDays < 7 && sinceLastWarning.TotalHours >= 2)
          {
              LogRenew($"Your database license will expire in less than a week, on the {expiry:dd MMM yyyy} at {expiry:hh:mm tt}.");
              _expiredLicenseCounter = 0;
          }
          ++_expiredLicenseCounter;
      }

      private static void BeginReadOnly()
      {
          if (!IsReadOnly)
          {
              LogImportant(
                  "Your database is now in read-only mode, since your license is invalid; you will be unable to save any records to the database until you renew your license. For help with renewing your license, please see the documentation at https://prsdigital.com.au/wiki/index.php/License_Renewal.");
              _readOnly = true;
          }
      }

      private static void EndReadOnly()
      {
          if (IsReadOnly)
          {
              LogImportant("Valid license found; the database is no longer read-only.");
              _readOnly = false;
          }
      }

      /// <summary>Start re-running <see cref="AssertLicense"/> every <see cref="LicenseCheckInterval"/>.</summary>
      private static void BeginLicenseCheckTimer()
      {
          // Remove any earlier subscription so repeated Start() calls don't run the check several times per tick.
          LicenseTimer.Elapsed -= LicenseTimer_Elapsed;
          LicenseTimer.Elapsed += LicenseTimer_Elapsed;
          LicenseTimer.Start();
      }

      private static void LicenseTimer_Elapsed(object? sender, System.Timers.ElapsedEventArgs e)
      {
          try
          {
              AssertLicense();
          }
          catch (Exception err)
          {
              // Report the failure rather than letting it escape into the timer callback unnoticed.
              LogError($"Periodic license check failed: {err.Message}");
          }
      }

      /// <summary>
      /// Check the stored license and enter or leave read-only mode accordingly. Does nothing when
      /// <c>CoreUtils.GetVersion()</c> returns "???" (presumably an unversioned development build).
      /// </summary>
      public static void AssertLicense()
      {
          if (string.Equals(CoreUtils.GetVersion(), "???"))
          {
              return;
          }

          var result = CheckLicenseValidity(out DateTime expiry);
          switch (result)
          {
              case LicenseValidation.Valid:
                  LogLicenseExpiry(expiry);
                  EndReadOnly();
                  break;
              case LicenseValidation.Missing:
                  LogImportant("Database is unlicensed!");
                  BeginReadOnly();
                  break;
              case LicenseValidation.Expired:
                  LogImportant("Database license has expired!");
                  BeginReadOnly();
                  break;
              case LicenseValidation.Corrupt:
                  LogImportant("Database license is corrupt - you will need to renew your license.");
                  BeginReadOnly();
                  break;
              case LicenseValidation.Tampered:
                  LogImportant("Database license has been tampered with - you will need to renew your license.");
                  BeginReadOnly();
                  break;
          }
      }

      #endregion

      #region Logging

      private static void LogInfo(string message)
      {
          Logger.Send(LogType.Information, "", message);
      }

      private static void LogImportant(string message)
      {
          Logger.Send(LogType.Important, "", message);
      }

      private static void LogError(string message)
      {
          Logger.Send(LogType.Error, "", message);
      }

      #endregion

      /// <summary>
      /// Instantiate every registered store type once, give it a fresh provider and
      /// <c>Logger.Main</c>, and call its <c>Init</c> method.
      /// </summary>
      public static void InitStores()
      {
          foreach (var storetype in stores)
          {
              var store = (Activator.CreateInstance(storetype) as IStore)!;
              store.Provider = ProviderFactory.NewProvider(Logger.Main);
              store.Logger = Logger.Main;
              store.Init();
          }
      }

      /// <summary>
      /// Create a store for entity <paramref name="type"/>. The first registered store that subclasses
      /// <c>Store&lt;type&gt;</c> is used; otherwise <see cref="DefaultStore"/> is closed over the type.
      /// The store is given a new provider and the supplied user/session details.
      /// </summary>
      public static IStore FindStore(Type type, Guid userguid, string userid, Platform platform, string version, Logger logger)
      {
          var genericStore = typeof(Store<>).MakeGenericType(type);
          var storeType = Stores.FirstOrDefault(t => t.IsSubclassOf(genericStore)) ?? DefaultStore.MakeGenericType(type);
          var store = (Activator.CreateInstance(storeType) as IStore)!;
          store.Provider = ProviderFactory.NewProvider(logger);
          store.UserGuid = userguid;
          store.UserID = userid;
          store.Platform = platform;
          store.Version = version;
          store.Logger = logger;
          return store;
      }

      /// <summary>Strongly-typed wrapper around FindStore for <typeparamref name="TEntity"/>.</summary>
      public static IStore<TEntity> FindStore<TEntity>(Guid userguid, string userid, Platform platform, string version, Logger logger)
          where TEntity : Entity, new()
      {
          return (FindStore(typeof(TEntity), userguid, userid, platform, version, logger) as IStore<TEntity>)!;
      }

      // Invoked via reflection from QueryMultiple; must remain a single, private static method.
      private static CoreTable DoQueryMultipleQuery<TEntity>(
          IQueryDef query,
          Guid userguid, string userid, Platform platform, string version, Logger logger)
          where TEntity : Entity, new()
      {
          var store = FindStore<TEntity>(userguid, userid, platform, version, logger);
          return store.Query(query.Filter as Filter<TEntity>, query.Columns as Columns<TEntity>, query.SortOrder as SortOrder<TEntity>);
      }

      /// <summary>
      /// Run every query concurrently on the thread pool, each against the store for its entity type.
      /// </summary>
      /// <returns>The result tables, keyed like <paramref name="queries"/>.</returns>
      /// <exception cref="AggregateException">One or more queries failed.</exception>
      public static Dictionary<string, CoreTable> QueryMultiple(
          Dictionary<string, IQueryDef> queries,
          Guid userguid, string userid, Platform platform, string version, Logger logger)
      {
          var result = new Dictionary<string, CoreTable>();
          // Dictionary<TKey, TValue> does not support concurrent writers, so insertions are serialised.
          var resultLock = new object();
          var queryMethod = typeof(DbFactory).GetMethod(nameof(DoQueryMultipleQuery), BindingFlags.NonPublic | BindingFlags.Static)!;

          var tasks = queries.Select(item => Task.Run(() =>
          {
              // DoQueryMultipleQuery is static, so no target instance is needed.
              var table = (queryMethod.MakeGenericMethod(item.Value.Type).Invoke(null, new object[]
              {
                  item.Value,
                  userguid, userid, platform, version, logger
              }) as CoreTable)!;
              lock (resultLock)
              {
                  result[item.Key] = table;
              }
          })).ToArray();

          Task.WaitAll(tasks);
          return result;
      }

      #region Supported Types

      private class ModuleConfiguration : Dictionary<string, bool>, ILocalConfigurationSettings
      {
      }

      private static Type[]? _dbtypes;

      /// <summary>Entity names of the enabled entity types, with '.' replaced by '_'.</summary>
      public static IEnumerable<string> SupportedTypes()
      {
          _dbtypes ??= LoadSupportedTypes();
          return _dbtypes.Select(x => x.EntityName().Replace(".", "_"));
      }

      /// <summary>
      /// Read the per-entity enabled flags from the local module configuration stored alongside the
      /// provider URL. Entities without an entry are enabled and the configuration is saved back.
      /// </summary>
      private static Type[] LoadSupportedTypes()
      {
          var result = new List<Type>();
          // NOTE(review): the URL is lower-cased before being split; kept because it decides which
          // configuration file existing installations read and write.
          var path = ProviderFactory.URL.ToLower();
          var directory = Path.GetDirectoryName(path) ?? "";
          var fileName = Path.GetFileName(path);
          var config = new LocalConfiguration<ModuleConfiguration>(directory, fileName).Load();
          var changed = false;

          foreach (var type in Entities)
          {
              var key = type.EntityName();
              if (!config.TryGetValue(key, out var enabled))
              {
                  // No entry yet: enable the entity and persist the new entry below.
                  config[key] = true;
                  result.Add(type);
                  changed = true;
              }
              else if (enabled)
              {
                  result.Add(type);
              }
              else
              {
                  Logger.Send(LogType.Information, "", string.Format("Entity [{0}] is disabled", key));
              }
          }

          if (changed)
          {
              new LocalConfiguration<ModuleConfiguration>(directory, fileName).Save(config);
          }
          return result.ToArray();
      }

      /// <summary>Whether <typeparamref name="T"/> is an enabled entity type.</summary>
      public static bool IsSupported<T>() where T : Entity
      {
          _dbtypes ??= LoadSupportedTypes();
          return _dbtypes.Contains(typeof(T));
      }

      #endregion

      #region Scripts and store registration

      /// <summary>
      /// Rebuild <see cref="LoadedScripts"/>: load every Before/After Query, Save and Delete script
      /// plus AfterLoad scripts, compile each, cache the successful ones under
      /// "{Section} {ScriptType}" and log compile failures as errors.
      /// </summary>
      public static void LoadScripts()
      {
          Logger.Send(LogType.Information, "", "Loading Script Cache...");
          LoadedScripts.Clear();
          var scripts = ProviderFactory.NewProvider(Logger.Main).Load(
              new Filter<Script>(x => x.ScriptType).IsEqualTo(ScriptType.BeforeQuery)
                  .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterQuery)
                  .Or(x => x.ScriptType).IsEqualTo(ScriptType.BeforeSave)
                  .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterSave)
                  .Or(x => x.ScriptType).IsEqualTo(ScriptType.BeforeDelete)
                  .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterDelete)
                  .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterLoad)
          );

          foreach (var script in scripts)
          {
              var doc = new ScriptDocument(script.Code);
              if (doc.Compile())
              {
                  Logger.Send(LogType.Information, "", $"- {script.Section}.{script.ScriptType} Compiled Successfully");
                  LoadedScripts[$"{script.Section} {script.ScriptType}"] = doc;
              }
              else
              {
                  Logger.Send(LogType.Error, "", $"- {script.Section}.{script.ScriptType} Compile Exception:\n{doc.Result}");
              }
          }
          Logger.Send(LogType.Information, "", "Loading Script Cache Complete");
      }

      private static Type[] stores = Array.Empty<Type>();

      /// <summary>
      /// Keep only concrete, non-generic classes from <paramref name="types"/> and register them as stores.
      /// </summary>
      /// <exception cref="ArgumentException">A remaining type does not implement <see cref="IStore"/>.</exception>
      private static void SetStoreTypes(Type[] types)
      {
          var concrete = types
              .Where(t => t.IsClass && !t.IsAbstract && !t.IsGenericType)
              .ToArray();
          var invalid = concrete.FirstOrDefault(t => !t.GetInterfaces().Contains(typeof(IStore)));
          if (invalid is not null)
          {
              throw new ArgumentException(string.Format("{0} is not a valid store", invalid.Name));
          }
          stores = concrete;
      }

      #endregion
  }