using System.Reflection;
using FluentResults;
using InABox.Clients;
using InABox.Configuration;
using InABox.Core;
using InABox.Scripting;

namespace InABox.Database;

/// <summary>
/// Per-database metadata, persisted by <see cref="DbFactory"/> as the global setting whose
/// section is <c>DatabaseMetadata</c>.
/// </summary>
public class DatabaseMetadata : BaseObject, IGlobalConfigurationSettings
{
    /// <summary>Identifier of the database; a fresh random GUID until replaced by the stored value.</summary>
    public Guid DatabaseID { get; set; } = Guid.NewGuid();
}

/// <summary>
/// Signals that the database is read-only due to PRS license expiry.
/// </summary>
public class DbLockedException : Exception
{
    public DbLockedException() : base("Database is read-only due to PRS license expiry.")
    {
    }
}

/// <summary>
/// Static entry point of the database layer: owns the provider factory, the database metadata,
/// license checking / read-only mode, and creation of entity stores.
/// </summary>
public static class DbFactory
{
    /// <summary>Script cache; <see cref="LoadScripts"/> clears it before reloading scripts.</summary>
    public static Dictionary LoadedScripts = new();

    private static DatabaseMetadata MetaData { get; set; } = new();

    /// <summary>
    /// Unique identifier of this database. Assigning it immediately persists the metadata
    /// through the current provider.
    /// </summary>
    public static Guid ID
    {
        get => MetaData.DatabaseID;
        set
        {
            MetaData.DatabaseID = value;
            SaveMetadata();
        }
    }

    private static IProviderFactory? _providerFactory;

    /// <summary>
    /// The provider factory used for all database access.
    /// </summary>
    /// <exception cref="Exception">The property is read before a factory has been assigned.</exception>
    public static IProviderFactory ProviderFactory
    {
        get => _providerFactory ?? throw new Exception("Provider is not set");
        set => _providerFactory = value;
    }

    /// <summary>Whether <see cref="ProviderFactory"/> has been assigned.</summary>
    public static bool IsProviderSet => _providerFactory is not null;

    public static string? ColorScheme { get; set; }

    public static byte[]? Logo { get; set; }

    // See notes in Request.DatabaseInfo class
    // Once RPC transport is stable, these settings need
    // to be removed
    public static int RestPort { get; set; }

    public static int RPCPort { get; set; }

    /// <summary>
    /// Return every entity in <see cref="CoreUtils.Entities"/> that passes the <c>HasInterface</c> check.
    /// </summary>
    public static IEnumerable Entities => CoreUtils.Entities.Where(x => x.HasInterface());

    /// <summary>Registered store types; assigning this property delegates to <c>SetStoreTypes</c>.</summary>
    public static Type[] Stores
    {
        get => stores;
        set => SetStoreTypes(value);
    }

    public static DateTime Expiry { get; set; }

    /// <summary>Creates a new provider from <see cref="ProviderFactory"/> using <paramref name="logger"/>.</summary>
    public static IProvider NewProvider(Logger logger) => ProviderFactory.NewProvider(logger);

    /// <summary>
    /// Starts the database layer. The steps, in order:
    /// <list type="number">
    /// <item>Check licensing.</item>
    /// <item>Start the provider for every concrete, non-generic <see cref="Entity"/> type.</item>
    /// <item>Load or create the database metadata.</item>
    /// <item>Migrate the database.</item>
    /// <item>Load custom properties into the schema.</item>
    /// <item>Validate the license and start the periodic re-check.</item>
    /// <item>Initialise the stores and load scripts.</item>
    /// </list>
    /// </summary>
    /// <exception cref="Exception">The database migration failed.</exception>
    public static void Start()
    {
        CoreUtils.CheckLicensing();

        // Start the provider
        ProviderFactory.Types = Entities.Where(x =>
            x.IsClass
            && !x.IsGenericType
            && x.IsSubclassOf(typeof(Entity))
        ).ToArray();
        ProviderFactory.Start();

        CheckMetadata();

        if (!DataUpdater.MigrateDatabase())
        {
            throw new Exception("Database migration failed. Aborting startup");
        }

        // Load up the custom properties. We can't use clients here (we're already inside the
        // database layer), so we query the provider directly.
        var props = ProviderFactory.NewProvider(Logger.Main).Query().ToArray();
        DatabaseSchema.Load(props);

        AssertLicense();
        BeginLicenseCheckTimer();

        InitStores();
        LoadScripts();
    }

    #region MetaData

    /// <summary>Persists <see cref="MetaData"/> as the <c>DatabaseMetadata</c> global setting.</summary>
    private static void SaveMetadata()
    {
        var settings = new GlobalSettings
        {
            Section = nameof(DatabaseMetadata),
            Key = "",
            Contents = Serialization.Serialize(MetaData)
        };
        ProviderFactory.NewProvider(Logger.Main).Save(settings);
    }

    /// <summary>
    /// Loads the stored <see cref="DatabaseMetadata"/>. If no setting exists (or it deserialises to
    /// null), fresh metadata with a new ID is created and saved.
    /// </summary>
    private static void CheckMetadata()
    {
        var result = ProviderFactory.NewProvider(Logger.Main)
            .Query(new Filter(x => x.Section).IsEqualTo(nameof(DatabaseMetadata)))
            .Rows.FirstOrDefault()?.ToObject();
        var data = result is not null ? Serialization.Deserialize(result.Contents) : null;
        if (data is null)
        {
            MetaData = new DatabaseMetadata();
            SaveMetadata();
        }
        else
        {
            MetaData = data;
        }
    }

    #endregion

    #region License

    private enum LicenseValidation
    {
        Valid,
        Missing,
        Expired,
        Corrupt,
        Tampered
    }

    /// <summary>
    /// Loads the stored license and validates it.
    /// </summary>
    /// <remarks>
    /// The checks, in order: the license exists, it decrypts, its MAC addresses match, every
    /// user-tracking item it references exists, and it has not expired.
    /// </remarks>
    /// <param name="expiry">
    /// The license expiry once the license has been decoded and verified; otherwise <see cref="DateTime.MinValue"/>.
    /// </param>
    private static LicenseValidation CheckLicenseValidity(out DateTime expiry)
    {
        var provider = ProviderFactory.NewProvider(Logger.New());

        expiry = DateTime.MinValue;
        var license = provider.Load().FirstOrDefault();
        if (license is null)
        {
            return LicenseValidation.Missing;
        }

        if (!LicenseUtils.TryDecryptLicense(license.Data, out var licenseData, out _))
        {
            return LicenseValidation.Corrupt;
        }

        if (!LicenseUtils.ValidateMacAddresses(licenseData.Addresses))
        {
            return LicenseValidation.Tampered;
        }

        // Every user-tracking item referenced by the license must exist in the database.
        // A HashSet keeps each membership test O(1) instead of scanning an array per item.
        var userTrackingItems = provider.Query(
                new Filter(x => x.ID).InList(licenseData.UserTrackingItems),
                Columns.None().Add(x => x.ID),
                log: false
            ).Rows
            .Select(r => r.Get(c => c.ID))
            .ToHashSet();
        foreach (var item in licenseData.UserTrackingItems)
        {
            if (!userTrackingItems.Contains(item))
            {
                return LicenseValidation.Tampered;
            }
        }

        expiry = licenseData.Expiry;
        if (licenseData.Expiry < DateTime.Now)
        {
            return LicenseValidation.Expired;
        }

        return LicenseValidation.Valid;
    }

    // Number of license checks since the last "three days" / "week" expiry warning; used to throttle those warnings.
    private static int _expiredLicenseCounter = 0;

    private static TimeSpan LicenseCheckInterval = TimeSpan.FromMinutes(10);

    private static bool _readOnly;

    /// <summary>True while the license is invalid and the database has been switched to read-only mode.</summary>
    public static bool IsReadOnly => _readOnly;

    // Must stay declared after LicenseCheckInterval: static field initialisers run in textual order.
    private static System.Timers.Timer LicenseTimer = new System.Timers.Timer(LicenseCheckInterval.TotalMilliseconds)
    {
        AutoReset = true
    };

    private static void LogRenew(string message)
    {
        LogImportant($"{message} Please renew your license before then, or your database will go into read-only mode; it will be locked for saving anything until you renew your license. For help with renewing your license, please see the documentation at https://prsdigital.com.au/wiki/index.php/License_Renewal.");
    }

    public static void LogReadOnly()
    {
        LogImportant($"Your database is in read-only mode; please renew your license to enable database updates.");
    }

    /// <summary>
    /// Logs an upcoming-expiry warning for a currently valid license.
    /// </summary>
    /// <remarks>
    /// A warning is logged on every check on the expiry day and within the last 24 hours. Within
    /// three days it is logged roughly hourly, and within a week roughly every two hours; those two
    /// cases are throttled by <c>_expiredLicenseCounter</c>.
    /// </remarks>
    private static void LogLicenseExpiry(DateTime expiry)
    {
        if (expiry.Date == DateTime.Today)
        {
            // Format the DateTime itself: TimeSpan has no "HH" specifier and requires escaped
            // separators, so formatting expiry.TimeOfDay with "HH:mm" throws FormatException.
            LogRenew($"Your database license is expiring today at {expiry:HH:mm}!");
            return;
        }

        var diffInDays = (expiry - DateTime.Now).TotalDays;
        if (diffInDays < 1)
        {
            LogRenew($"Your database license will expire in less than a day, on the {expiry:dd MMM yyyy} at {expiry:hh:mm tt}.");
        }
        else if (diffInDays < 3 && (_expiredLicenseCounter * LicenseCheckInterval).TotalHours >= 1)
        {
            LogRenew($"Your database license will expire in less than three days, on the {expiry:dd MMM yyyy} at {expiry:hh:mm tt}.");
            _expiredLicenseCounter = 0;
        }
        else if (diffInDays < 7 && (_expiredLicenseCounter * LicenseCheckInterval).TotalHours >= 2)
        {
            LogRenew($"Your database license will expire in less than a week, on the {expiry:dd MMM yyyy} at {expiry:hh:mm tt}.");
            _expiredLicenseCounter = 0;
        }
        ++_expiredLicenseCounter;
    }

    private static void BeginReadOnly()
    {
        if (!IsReadOnly)
        {
            LogImportant(
                "Your database is now in read-only mode, since your license is invalid; you will be unable to save any records to the database until you renew your license. For help with renewing your license, please see the documentation at https://prsdigital.com.au/wiki/index.php/License_Renewal.");
            _readOnly = true;
        }
    }

    private static void EndReadOnly()
    {
        if (IsReadOnly)
        {
            LogImportant("Valid license found; the database is no longer read-only.");
            _readOnly = false;
        }
    }

    private static void BeginLicenseCheckTimer()
    {
        // Unsubscribe first so that calling Start() more than once does not attach the handler
        // twice (each extra subscription would repeat the whole license check on every tick).
        LicenseTimer.Elapsed -= LicenseTimer_Elapsed;
        LicenseTimer.Elapsed += LicenseTimer_Elapsed;
        LicenseTimer.Start();
    }

    private static void LicenseTimer_Elapsed(object? sender, System.Timers.ElapsedEventArgs e)
    {
        // System.Timers.Timer suppresses exceptions thrown by Elapsed handlers, so a failed check
        // (e.g. the license query failing) would otherwise disappear without trace.
        try
        {
            AssertLicense();
        }
        catch (Exception ex)
        {
            LogError($"License check failed: {ex.Message}");
        }
    }

    /// <summary>
    /// Validates the stored license. A valid license leaves read-only mode (logging any upcoming
    /// expiry); any other result logs the reason and enters read-only mode.
    /// </summary>
    public static void AssertLicense()
    {
        var result = CheckLicenseValidity(out DateTime expiry);
        switch (result)
        {
            case LicenseValidation.Valid:
                LogLicenseExpiry(expiry);
                EndReadOnly();
                break;
            case LicenseValidation.Missing:
                LogImportant("Database is unlicensed!");
                BeginReadOnly();
                break;
            case LicenseValidation.Expired:
                LogImportant("Database license has expired!");
                BeginReadOnly();
                break;
            case LicenseValidation.Corrupt:
                LogImportant("Database license is corrupt - you will need to renew your license.");
                BeginReadOnly();
                break;
            case LicenseValidation.Tampered:
                LogImportant("Database license has been tampered with - you will need to renew your license.");
                BeginReadOnly();
                break;
        }
    }

    #endregion

    #region Logging

    private static void LogInfo(string message)
    {
        Logger.Send(LogType.Information, "", message);
    }

    private static void LogImportant(string message)
    {
        Logger.Send(LogType.Important, "", message);
    }

    private static void LogError(string message)
    {
        Logger.Send(LogType.Error, "", message);
    }

    #endregion

    /// <summary>
    /// Instantiates every registered store type, gives it a provider and the main logger,
    /// and calls its <c>Init</c>.
    /// </summary>
    public static void InitStores()
    {
        foreach (var storetype in stores)
        {
            var store = (IStore)Activator.CreateInstance(storetype)!;
            store.Provider = ProviderFactory.NewProvider(Logger.Main);
            store.Logger = Logger.Main;
            store.Init();
        }
    }

    /// <summary>
    /// Creates a store for <paramref name="type"/> and configures it with the caller's identity,
    /// platform, version and logger.
    /// </summary>
    /// <remarks>
    /// Uses the first registered store type that subclasses <c>Store&lt;type&gt;</c>; if none is
    /// registered, <c>Store&lt;type&gt;</c> itself is used.
    /// </remarks>
    public static IStore FindStore(Type type, Guid userguid, string userid, Platform platform, string version, Logger logger)
    {
        var defType = typeof(Store<>).MakeGenericType(type);
        Type? subType = Stores.FirstOrDefault(myType => myType.IsSubclassOf(defType));

        var store = (IStore)Activator.CreateInstance(subType ?? defType)!;
        store.Provider = ProviderFactory.NewProvider(logger);
        store.UserGuid = userguid;
        store.UserID = userid;
        store.Platform = platform;
        store.Version = version;
        store.Logger = logger;
        return store;
    }

    public static IStore FindStore(Guid userguid, string userid, Platform platform, string version, Logger logger)
        where TEntity : Entity, new()
    {
        return (FindStore(typeof(TEntity), userguid, userid, platform, version, logger) as IStore)!;
    }

    private static CoreTable DoQueryMultipleQuery(
        IQueryDef query,
        Guid userguid, string userid, Platform platform, string version, Logger logger)
        where TEntity : Entity, new()
    {
        var store = FindStore(userguid, userid, platform, version, logger);
        return store.Query(query.Filter as Filter, query.Columns as Columns, query.SortOrder as SortOrder);
    }

    /// <summary>
    /// Runs every query in <paramref name="queries"/> in parallel against the store for its entity
    /// type. The resulting tables are returned under the same keys.
    /// </summary>
    /// <exception cref="AggregateException">One or more of the queries failed.</exception>
    public static Dictionary QueryMultiple(
        Dictionary queries,
        Guid userguid, string userid, Platform platform, string version, Logger logger)
    {
        var queryMethod = typeof(DbFactory).GetMethod(nameof(DoQueryMultipleQuery), BindingFlags.NonPublic | BindingFlags.Static)!;

        // Each task returns its own (key, table) pair instead of writing into a shared Dictionary.
        // Dictionary is not safe for concurrent writers, so the results are merged only after all
        // tasks finish. DoQueryMultipleQuery is static, so Invoke needs no target.
        var tasks = queries
            .Select(item => Task.Run(() =>
            (
                Key: item.Key,
                Table: (queryMethod.MakeGenericMethod(item.Value.Type)
                    .Invoke(null, new object[] { item.Value, userguid, userid, platform, version, logger }) as CoreTable)!
            )))
            .ToArray();
        Task.WaitAll(tasks);

        var result = new Dictionary();
        foreach (var task in tasks)
        {
            result[task.Result.Key] = task.Result.Table;
        }
        return result;
    }

    #region Supported Types

    private class ModuleConfiguration : Dictionary, ILocalConfigurationSettings
    {
    }

    // Cache of the enabled entity types, filled lazily from LoadSupportedTypes().
    private static Type[]?
_dbtypes; public static IEnumerable SupportedTypes() { _dbtypes ??= LoadSupportedTypes(); return _dbtypes.Select(x => x.EntityName().Replace(".", "_")); } private static Type[] LoadSupportedTypes() { var result = new List(); var path = ProviderFactory.URL.ToLower(); var config = new LocalConfiguration(Path.GetDirectoryName(path) ?? "", Path.GetFileName(path)).Load(); var bChanged = false; foreach (var type in Entities) { var key = type.EntityName(); if (config.TryGetValue(key, out bool value)) { if (value) //Logger.Send(LogType.Information, "", String.Format("{0} is enabled", key)); result.Add(type); else Logger.Send(LogType.Information, "", string.Format("Entity [{0}] is disabled", key)); } else { //Logger.Send(LogType.Information, "", String.Format("{0} does not exist - enabling", key)); config[key] = true; result.Add(type); bChanged = true; } } if (bChanged) new LocalConfiguration(Path.GetDirectoryName(path) ?? "", Path.GetFileName(path)).Save(config); return result.ToArray(); } public static bool IsSupported() where T : Entity { _dbtypes ??= LoadSupportedTypes(); return _dbtypes.Contains(typeof(T)); } #endregion //public static void OpenSession(bool write) //{ // Provider.OpenSession(write); //} //public static void CloseSession() //{ // Provider.CloseSession(); //} #region Private Methods public static void LoadScripts() { Logger.Send(LogType.Information, "", "Loading Script Cache..."); LoadedScripts.Clear(); var scripts = ProviderFactory.NewProvider(Logger.Main).Load( new Filter