DbFactory.cs 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560
  1. using System.Composition;
  2. using System.Diagnostics.CodeAnalysis;
  3. using System.Globalization;
  4. using System.Reflection;
  5. using InABox.Clients;
  6. using InABox.Configuration;
  7. using InABox.Core;
  8. using InABox.Scripting;
  9. using Microsoft.CodeAnalysis.CSharp;
  10. namespace InABox.Database
  11. {
  12. public static class DbFactory
  13. {
  14. public static Dictionary<string, ScriptDocument> LoadedScripts = new();
  15. private static string _deviceid = "";
  16. private static IProvider? _provider;
  17. public static IProvider Provider
  18. {
  19. get => _provider ?? throw new Exception("Provider is not set");
  20. set => _provider = value;
  21. }
  22. public static bool IsProviderSet => _provider is not null;
  23. public static string? ColorScheme { get; set; }
  24. public static byte[]? Logo { get; set; }
  25. //public static Type[] Entities { get { return entities; } set { SetEntityTypes(value); } }
  26. public static IEnumerable<Type> Entities
  27. {
  28. get { return CoreUtils.Entities.Where(x => x.GetInterfaces().Contains(typeof(IPersistent))); }
  29. }
  30. public static Type[] Stores
  31. {
  32. get => stores;
  33. set => SetStoreTypes(value);
  34. }
  35. public static DateTime Expiry { get; set; }
  36. public static void Start(string deviceid)
  37. {
  38. CoreUtils.CheckLicensing();
  39. _deviceid = deviceid;
  40. var status = ValidateSchema();
  41. if (status.Equals(SchemaStatus.New))
  42. try
  43. {
  44. Provider.CreateSchema(ConsolidatedObjectModel().ToArray());
  45. SaveSchema();
  46. }
  47. catch (Exception err)
  48. {
  49. throw new Exception(string.Format("Unable to Create Schema\n\n{0}", err.Message));
  50. }
  51. else if (status.Equals(SchemaStatus.Changed))
  52. try
  53. {
  54. Provider.UpgradeSchema(ConsolidatedObjectModel().ToArray());
  55. SaveSchema();
  56. }
  57. catch (Exception err)
  58. {
  59. throw new Exception(string.Format("Unable to Update Schema\n\n{0}", err.Message));
  60. }
  61. // Start the provider
  62. Provider.Types = ConsolidatedObjectModel();
  63. Provider.OnLog += LogMessage;
  64. Provider.Start();
  65. if (!DataUpdater.MigrateDatabase())
  66. {
  67. throw new Exception("Database migration failed. Aborting startup");
  68. }
  69. //Load up your custom properties here!
  70. // Can't use clients (b/c were inside the database layer already
  71. // but we can simply access the store directly :-)
  72. //CustomProperty[] props = FindStore<CustomProperty>("", "", "", "").Load(new Filter<CustomProperty>(x=>x.ID).IsNotEqualTo(Guid.Empty),null);
  73. var props = Provider.Query<CustomProperty>().Rows.Select(x => x.ToObject<CustomProperty>()).ToArray();
  74. DatabaseSchema.Load(props);
  75. AssertLicense();
  76. BeginLicenseCheckTimer();
  77. InitStores();
  78. LoadScripts();
  79. }
  80. #region License
  81. private enum LicenseValidation
  82. {
  83. Valid,
  84. Missing,
  85. Expired,
  86. Corrupt,
  87. Tampered
  88. }
  89. private static LicenseValidation CheckLicenseValidity(out License? license, out LicenseData? licenseData)
  90. {
  91. license = Provider.Load<License>().FirstOrDefault();
  92. if (license is null)
  93. {
  94. licenseData = null;
  95. return LicenseValidation.Missing;
  96. }
  97. if (!LicenseUtils.TryDecryptLicense(license.Data, out licenseData, out var error))
  98. return LicenseValidation.Corrupt;
  99. if (licenseData.Expiry < DateTime.Now)
  100. return LicenseValidation.Expired;
  101. var userTrackingItems = Provider.Query(
  102. new Filter<UserTracking>(x => x.ID).InList(licenseData.UserTrackingItems),
  103. new Columns<UserTracking>(x => x.ID), log: false).Rows.Select(x => x.Get<UserTracking, Guid>(x => x.ID));
  104. foreach(var item in licenseData.UserTrackingItems)
  105. {
  106. if (!userTrackingItems.Contains(item))
  107. {
  108. return LicenseValidation.Tampered;
  109. }
  110. }
  111. return LicenseValidation.Valid;
  112. }
  113. private static int _expiredLicenseCounter = 0;
  114. private static TimeSpan LicenseCheckInterval = TimeSpan.FromMinutes(10);
  115. private static bool _readOnly;
  116. public static bool IsReadOnly { get => _readOnly; }
  117. private static System.Timers.Timer LicenseTimer = new System.Timers.Timer(LicenseCheckInterval.TotalMilliseconds) { AutoReset = true };
  118. private static void LogRenew(string message)
  119. {
  120. LogImportant($"{message} Please renew your license before then, or your database will go into read-only mode; it will be locked for saving anything until you renew your license. For help with renewing your license, please see the documentation at https://prsdigital.com.au/wiki/index.php/License_Renewal.");
  121. }
  122. private static void LogLicenseExpiry(DateTime expiry)
  123. {
  124. if (expiry.Date == DateTime.Today)
  125. {
  126. LogRenew($"Your database license is expiring today at {expiry.TimeOfDay:HH:mm}!");
  127. return;
  128. }
  129. var diffInDays = (expiry - DateTime.Now).TotalDays;
  130. if(diffInDays < 1)
  131. {
  132. LogRenew($"Your database license will expire in less than a day, on the {expiry:dd MMM yyyy} at {expiry:hh:mm:tt}.");
  133. }
  134. else if(diffInDays < 3 && (_expiredLicenseCounter * LicenseCheckInterval).TotalHours >= 1)
  135. {
  136. LogRenew($"Your database license will expire in less than three days, on the {expiry:dd MMM yyyy} at {expiry:hh:mm:tt}.");
  137. _expiredLicenseCounter = 0;
  138. }
  139. else if(diffInDays < 7 && (_expiredLicenseCounter * LicenseCheckInterval).TotalHours >= 2)
  140. {
  141. LogRenew($"Your database license will expire in less than a week, on the {expiry:dd MMM yyyy} at {expiry:hh:mm:tt}.");
  142. _expiredLicenseCounter = 0;
  143. }
  144. ++_expiredLicenseCounter;
  145. }
  146. public static void LogReadOnly()
  147. {
  148. LogError("Database is read-only because your license is invalid!");
  149. }
  150. private static void BeginReadOnly()
  151. {
  152. LogImportant("Your database is now in read-only mode, since your license is invalid; you will be unable to save any records to the database until you renew your license. For help with renewing your license, please see the documentation at https://prsdigital.com.au/wiki/index.php/License_Renewal.");
  153. _readOnly = true;
  154. }
  155. private static void EndReadOnly()
  156. {
  157. LogImportant("Valid license found; the database is no longer read-only.");
  158. _readOnly = false;
  159. }
  160. private static void BeginLicenseCheckTimer()
  161. {
  162. LicenseTimer.Elapsed += LicenseTimer_Elapsed;
  163. LicenseTimer.Start();
  164. }
  165. private static void LicenseTimer_Elapsed(object? sender, System.Timers.ElapsedEventArgs e)
  166. {
  167. AssertLicense();
  168. }
  169. private static Random LicenseIDGenerate = new Random();
  170. private static void UpdateValidLicense(License license, LicenseData licenseData)
  171. {
  172. var ids = Provider.Query(
  173. new Filter<UserTracking>(x => x.Created).IsGreaterThanOrEqualTo(licenseData.LastRenewal),
  174. new Columns<UserTracking>(x => x.ID), log: false);
  175. var newIDList = new List<Guid>();
  176. if(ids.Rows.Count > 0)
  177. {
  178. for (int i = 0; i < 10; i++)
  179. {
  180. newIDList.Add(ids.Rows[LicenseIDGenerate.Next(0, ids.Rows.Count)].Get<UserTracking, Guid>(x => x.ID));
  181. }
  182. }
  183. licenseData.UserTrackingItems = newIDList.ToArray();
  184. if(LicenseUtils.TryEncryptLicense(licenseData, out var newData, out var error))
  185. {
  186. license.Data = newData;
  187. Provider.Save(license);
  188. }
  189. }
  190. private static void AssertLicense()
  191. {
  192. var result = CheckLicenseValidity(out var license, out var licenseData);
  193. if (IsReadOnly)
  194. {
  195. if(result == LicenseValidation.Valid)
  196. {
  197. EndReadOnly();
  198. }
  199. return;
  200. }
  201. // TODO: Switch to real system
  202. if(result != LicenseValidation.Valid)
  203. {
  204. var newLicense = LicenseUtils.GenerateNewLicense();
  205. if (LicenseUtils.TryEncryptLicense(newLicense, out var newData, out var error))
  206. {
  207. if (license == null)
  208. license = new License();
  209. license.Data = newData;
  210. Provider.Save(license);
  211. }
  212. else
  213. {
  214. Logger.Send(LogType.Error, "", $"Error updating license: {error}");
  215. }
  216. return;
  217. }
  218. else
  219. {
  220. return;
  221. }
  222. switch (result)
  223. {
  224. case LicenseValidation.Valid:
  225. LogLicenseExpiry(licenseData!.Expiry);
  226. UpdateValidLicense(license, licenseData);
  227. break;
  228. case LicenseValidation.Missing:
  229. LogImportant("Database is unlicensed!");
  230. BeginReadOnly();
  231. break;
  232. case LicenseValidation.Expired:
  233. LogImportant("Database license has expired!");
  234. BeginReadOnly();
  235. break;
  236. case LicenseValidation.Corrupt:
  237. LogImportant("Database license is corrupt - you will need to renew your license.");
  238. BeginReadOnly();
  239. break;
  240. case LicenseValidation.Tampered:
  241. LogImportant("Database license has been tampered with - you will need to renew your license.");
  242. BeginReadOnly();
  243. break;
  244. }
  245. }
  246. #endregion
  247. #region Logging
  248. private static void LogMessage(LogType type, string message)
  249. {
  250. Logger.Send(type, "", message);
  251. }
  252. private static void LogInfo(string message)
  253. {
  254. Logger.Send(LogType.Information, "", message);
  255. }
  256. private static void LogImportant(string message)
  257. {
  258. Logger.Send(LogType.Important, "", message);
  259. }
  260. private static void LogError(string message)
  261. {
  262. Logger.Send(LogType.Error, "", message);
  263. }
  264. #endregion
  265. public static void InitStores()
  266. {
  267. foreach (var storetype in stores)
  268. {
  269. var store = Activator.CreateInstance(storetype) as IStore;
  270. store.Provider = Provider;
  271. store.Init();
  272. }
  273. }
  274. public static IStore<TEntity> FindStore<TEntity>(Guid userguid, string userid, string platform, string version)
  275. where TEntity : Entity, new()
  276. {
  277. var defType = typeof(Store<>).MakeGenericType(typeof(TEntity));
  278. Type? subType = Stores.Where(myType => myType.IsSubclassOf(defType)).FirstOrDefault();
  279. var store = (Store<TEntity>)Activator.CreateInstance(subType ?? defType)!;
  280. store.Provider = Provider;
  281. store.UserGuid = userguid;
  282. store.UserID = userid;
  283. store.Platform = platform;
  284. store.Version = version;
  285. return store;
  286. }
  287. private static CoreTable DoQueryMultipleQuery<TEntity>(
  288. IQueryDef query,
  289. Guid userguid, string userid, string platform, string version)
  290. where TEntity : Entity, new()
  291. {
  292. var store = FindStore<TEntity>(userguid, userid, platform, version);
  293. return store.Query(query.Filter as Filter<TEntity>, query.Columns as Columns<TEntity>, query.SortOrder as SortOrder<TEntity>);
  294. }
  295. public static Dictionary<string, CoreTable> QueryMultiple(
  296. Dictionary<string, IQueryDef> queries,
  297. Guid userguid, string userid, string platform, string version)
  298. {
  299. var result = new Dictionary<string, CoreTable>();
  300. var queryMethod = typeof(DbFactory).GetMethod(nameof(DoQueryMultipleQuery), BindingFlags.NonPublic | BindingFlags.Static)!;
  301. var tasks = new List<Task>();
  302. foreach (var item in queries)
  303. tasks.Add(Task.Run(() =>
  304. {
  305. result[item.Key] = (queryMethod.MakeGenericMethod(item.Value.Type).Invoke(Provider, new object[]
  306. {
  307. item.Value,
  308. userguid, userid, platform, version
  309. }) as CoreTable)!;
  310. }));
  311. Task.WaitAll(tasks.ToArray());
  312. return result;
  313. }
  314. #region Supported Types
  315. private class ModuleConfiguration : Dictionary<string, bool>, LocalConfigurationSettings
  316. {
  317. }
  318. private static Type[]? _dbtypes;
  319. public static IEnumerable<string> SupportedTypes()
  320. {
  321. _dbtypes ??= LoadSupportedTypes();
  322. return _dbtypes.Select(x => x.EntityName().Replace(".", "_"));
  323. }
  324. private static Type[] LoadSupportedTypes()
  325. {
  326. var result = new List<Type>();
  327. var path = Provider.URL.ToLower();
  328. var config = new LocalConfiguration<ModuleConfiguration>(Path.GetDirectoryName(path) ?? "", Path.GetFileName(path)).Load();
  329. var bChanged = false;
  330. foreach (var type in Entities)
  331. {
  332. var key = type.EntityName();
  333. if (config.ContainsKey(key))
  334. {
  335. if (config[key])
  336. //Logger.Send(LogType.Information, "", String.Format("{0} is enabled", key));
  337. result.Add(type);
  338. else
  339. Logger.Send(LogType.Information, "", string.Format("Entity [{0}] is disabled", key));
  340. }
  341. else
  342. {
  343. //Logger.Send(LogType.Information, "", String.Format("{0} does not exist - enabling", key));
  344. config[key] = true;
  345. result.Add(type);
  346. bChanged = true;
  347. }
  348. }
  349. if (bChanged)
  350. new LocalConfiguration<ModuleConfiguration>(Path.GetDirectoryName(path) ?? "", Path.GetFileName(path)).Save(config);
  351. return result.ToArray();
  352. }
  353. public static bool IsSupported<T>() where T : Entity
  354. {
  355. _dbtypes ??= LoadSupportedTypes();
  356. return _dbtypes.Contains(typeof(T));
  357. }
  358. #endregion
  359. //public static void OpenSession(bool write)
  360. //{
  361. // Provider.OpenSession(write);
  362. //}
  363. //public static void CloseSession()
  364. //{
  365. // Provider.CloseSession();
  366. //}
  367. #region Private Methods
  368. public static void LoadScripts()
  369. {
  370. Logger.Send(LogType.Information, "", "Loading Script Cache...");
  371. LoadedScripts.Clear();
  372. var scripts = Provider.Load(
  373. new Filter<Script>
  374. (x => x.ScriptType).IsEqualTo(ScriptType.BeforeQuery)
  375. .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterQuery)
  376. .Or(x => x.ScriptType).IsEqualTo(ScriptType.BeforeSave)
  377. .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterSave)
  378. .Or(x => x.ScriptType).IsEqualTo(ScriptType.BeforeDelete)
  379. .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterDelete)
  380. .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterLoad)
  381. );
  382. foreach (var script in scripts)
  383. {
  384. var key = string.Format("{0} {1}", script.Section, script.ScriptType.ToString());
  385. var doc = new ScriptDocument(script.Code);
  386. if (doc.Compile())
  387. {
  388. Logger.Send(LogType.Information, "",
  389. string.Format("- {0}.{1} Compiled Successfully", script.Section, script.ScriptType.ToString()));
  390. LoadedScripts[key] = doc;
  391. }
  392. else
  393. {
  394. Logger.Send(LogType.Error, "",
  395. string.Format("- {0}.{1} Compile Exception:\n{2}", script.Section, script.ScriptType.ToString(), doc.Result));
  396. }
  397. }
  398. Logger.Send(LogType.Information, "", "Loading Script Cache Complete");
  399. }
  400. //private static Type[] entities = null;
  401. //private static void SetEntityTypes(Type[] types)
  402. //{
  403. // foreach (Type type in types)
  404. // {
  405. // if (!type.IsSubclassOf(typeof(Entity)))
  406. // throw new Exception(String.Format("{0} is not a valid entity", type.Name));
  407. // }
  408. // entities = types;
  409. //}
  410. private static Type[] stores = { };
  411. private static void SetStoreTypes(Type[] types)
  412. {
  413. types = types.Where(
  414. myType => myType.IsClass
  415. && !myType.IsAbstract
  416. && !myType.IsGenericType).ToArray();
  417. foreach (var type in types)
  418. if (!type.GetInterfaces().Contains(typeof(IStore)))
  419. throw new Exception(string.Format("{0} is not a valid store", type.Name));
  420. stores = types;
  421. }
  422. private static Type[] ConsolidatedObjectModel()
  423. {
  424. // Add the core types from InABox.Core
  425. var types = new List<Type>();
  426. //var coreTypes = CoreUtils.TypeList(
  427. // new Assembly[] { typeof(Entity).Assembly },
  428. // myType =>
  429. // myType.IsClass
  430. // && !myType.IsAbstract
  431. // && !myType.IsGenericType
  432. // && myType.IsSubclassOf(typeof(Entity))
  433. // && myType.GetInterfaces().Contains(typeof(IRemotable))
  434. //);
  435. //types.AddRange(coreTypes);
  436. // Now add the end-user object model
  437. types.AddRange(Entities.Where(x =>
  438. x.GetTypeInfo().IsClass
  439. && !x.GetTypeInfo().IsGenericType
  440. && x.GetTypeInfo().IsSubclassOf(typeof(Entity))
  441. ));
  442. return types.ToArray();
  443. }
  444. private enum SchemaStatus
  445. {
  446. New,
  447. Changed,
  448. Validated
  449. }
  450. private static Dictionary<string, Type> GetSchema()
  451. {
  452. var model = new Dictionary<string, Type>();
  453. var objectmodel = ConsolidatedObjectModel();
  454. foreach (var type in objectmodel)
  455. {
  456. Dictionary<string, Type> thismodel = CoreUtils.PropertyList(type, x => true, true);
  457. foreach (var key in thismodel.Keys)
  458. model[type.Name + "." + key] = thismodel[key];
  459. }
  460. return model;
  461. //return Serialization.Serialize(model, Formatting.Indented);
  462. }
  463. private static SchemaStatus ValidateSchema()
  464. {
  465. var db_schema = Provider.GetSchema();
  466. if (db_schema.Count() == 0)
  467. return SchemaStatus.New;
  468. var mdl_json = Serialization.Serialize(GetSchema());
  469. var db_json = Serialization.Serialize(db_schema);
  470. return mdl_json.Equals(db_json) ? SchemaStatus.Validated : SchemaStatus.Changed;
  471. }
  472. private static void SaveSchema()
  473. {
  474. Provider.SaveSchema(GetSchema());
  475. }
  476. #endregion
  477. }
  478. }