using System.Composition;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Reflection;
using InABox.Clients;
using InABox.Configuration;
using InABox.Core;
using InABox.Scripting;
using Microsoft.CodeAnalysis.CSharp;

namespace InABox.Database
{
    /// <summary>
    /// Static entry point of the database layer. Holds the active <see cref="IProvider"/>, creates or
    /// upgrades the schema on <see cref="Start"/>, runs the periodic license check, creates stores and
    /// caches compiled server-side scripts.
    /// </summary>
    public static class DbFactory
    {
        /// <summary>Successfully compiled scripts, keyed by "{Section} {ScriptType}" (see <see cref="LoadScripts"/>).</summary>
        public static Dictionary<string, ScriptDocument> LoadedScripts = new();

        private static string _deviceid = "";

        private static IProvider? _provider;

        /// <summary>The database provider in use.</summary>
        /// <exception cref="Exception">Thrown by the getter if no provider has been assigned.</exception>
        public static IProvider Provider
        {
            get => _provider ?? throw new Exception("Provider is not set");
            set => _provider = value;
        }

        public static string? ColorScheme { get; set; }

        public static byte[]? Logo { get; set; }

        /// <summary>Entity types from <see cref="CoreUtils.Entities"/> that implement <see cref="IPersistent"/>.</summary>
        public static IEnumerable<Type> Entities =>
            CoreUtils.Entities.Where(x => x.GetInterfaces().Contains(typeof(IPersistent)));

        /// <summary>
        /// Registered store types. Setting this keeps only concrete, non-generic classes and throws if any
        /// of them does not implement <see cref="IStore"/>.
        /// </summary>
        public static Type[] Stores
        {
            get => stores;
            set => SetStoreTypes(value);
        }

        public static DateTime Expiry { get; set; }

        /// <summary>
        /// Starts the database layer: creates or upgrades the schema as needed, starts the provider, runs
        /// data migrations, loads custom properties, checks the license (and starts the periodic license
        /// check), initialises stores and loads the script cache.
        /// </summary>
        /// <param name="deviceid">Identifier of the device starting the database; stored for later use.</param>
        /// <exception cref="Exception">
        /// Thrown if schema creation/upgrade fails (the original error is kept as the inner exception) or
        /// if the data migration reports failure.
        /// </exception>
        public static void Start(string deviceid)
        {
            CoreUtils.CheckLicensing();

            _deviceid = deviceid;

            var status = ValidateSchema();
            if (status == SchemaStatus.New)
            {
                try
                {
                    Provider.CreateSchema(ConsolidatedObjectModel());
                    SaveSchema();
                }
                catch (Exception err)
                {
                    // Keep the existing message, but preserve the cause for diagnostics.
                    throw new Exception(string.Format("Unable to Create Schema\n\n{0}", err.Message), err);
                }
            }
            else if (status == SchemaStatus.Changed)
            {
                try
                {
                    Provider.UpgradeSchema(ConsolidatedObjectModel());
                    SaveSchema();
                }
                catch (Exception err)
                {
                    throw new Exception(string.Format("Unable to Update Schema\n\n{0}", err.Message), err);
                }
            }

            // Start the provider
            Provider.Types = ConsolidatedObjectModel();
            Provider.OnLog += LogMessage;
            Provider.Start();

            if (!DataUpdater.MigrateDatabase())
            {
                throw new Exception("Database migration failed. Aborting startup");
            }

            // Load the custom properties. Clients can't be used here because we are already inside the
            // database layer, so the provider is queried directly.
            var props = Provider.Query<CustomProperty>().Rows.Select(x => x.ToObject<CustomProperty>()).ToArray();
            DatabaseSchema.Load(props);

            AssertLicense();
            BeginLicenseCheckTimer();

            InitStores();
            LoadScripts();
        }

        #region License

        private enum LicenseValidation
        {
            Valid,
            Missing,
            Expired,
            Corrupt,
            Tampered
        }

        /// <summary>
        /// Loads the first stored <see cref="License"/> and classifies it.
        /// </summary>
        /// <param name="license">The stored license, or null when none exists.</param>
        /// <param name="licenseData">The decrypted license data, or null when missing or undecryptable.</param>
        /// <returns>
        /// <see cref="LicenseValidation.Missing"/> if no license row exists,
        /// <see cref="LicenseValidation.Corrupt"/> if it cannot be decrypted,
        /// <see cref="LicenseValidation.Expired"/> if its expiry is in the past,
        /// <see cref="LicenseValidation.Tampered"/> if any UserTracking ID it records is not in the
        /// database, otherwise <see cref="LicenseValidation.Valid"/>.
        /// </returns>
        private static LicenseValidation CheckLicenseValidity(out License? license, out LicenseData? licenseData)
        {
            license = Provider.Load<License>().FirstOrDefault();
            if (license is null)
            {
                licenseData = null;
                return LicenseValidation.Missing;
            }

            if (!LicenseUtils.TryDecryptLicense(license.Data, out licenseData, out _))
                return LicenseValidation.Corrupt;

            if (licenseData.Expiry < DateTime.Now)
                return LicenseValidation.Expired;

            // Materialise once so the membership checks below don't re-enumerate the rows.
            var existingIDs = Provider.Query(
                    new Filter<UserTracking>(x => x.ID).InList(licenseData.UserTrackingItems),
                    new Columns<UserTracking>(x => x.ID), log: false)
                .Rows
                .Select(row => row.Get<UserTracking, Guid>(x => x.ID))
                .ToHashSet();

            return licenseData.UserTrackingItems.All(existingIDs.Contains)
                ? LicenseValidation.Valid
                : LicenseValidation.Tampered;
        }

        // Number of license checks since the last "expiring soon" reminder; throttles those reminders.
        private static int _expiredLicenseCounter = 0;

        // Must stay declared before LicenseTimer: static field initialisers run in textual order.
        private static readonly TimeSpan LicenseCheckInterval = TimeSpan.FromMinutes(10);

        private static bool _readOnly;

        /// <summary>True while the database has been placed into read-only mode by the license check.</summary>
        public static bool IsReadOnly => _readOnly;

        private static readonly System.Timers.Timer LicenseTimer =
            new System.Timers.Timer(LicenseCheckInterval.TotalMilliseconds) { AutoReset = true };

        private static void LogRenew(string message)
        {
            LogImportant($"{message} Please renew your license before then, or your database will go into read-only mode; it will be locked for saving anything until you renew your license. For help with renewing your license, please see the documentation at https://prs-software.com.au/wiki/index.php/License_Renewal.");
        }

        /// <summary>
        /// Logs a renewal reminder for a license expiring at <paramref name="expiry"/>. Same-day and
        /// under-one-day expiries log on every call; under three days logs at most hourly and under a
        /// week at most every two hours (measured in license-check intervals).
        /// </summary>
        private static void LogLicenseExpiry(DateTime expiry)
        {
            if (expiry.Date == DateTime.Today)
            {
                // Format the DateTime, not expiry.TimeOfDay: TimeSpan has no "HH" specifier and requires ':'
                // to be escaped, so the previous "{expiry.TimeOfDay:HH:mm}" threw a FormatException.
                LogRenew($"Your database license is expiring today at {expiry:HH:mm}!");
                return;
            }

            var diffInDays = (expiry - DateTime.Now).TotalDays;
            if (diffInDays < 1)
            {
                LogRenew($"Your database license will expire in less than a day, on the {expiry:dd MMM yyyy} at {expiry:hh:mm tt}.");
            }
            else if (diffInDays < 3 && (_expiredLicenseCounter * LicenseCheckInterval).TotalHours >= 1)
            {
                LogRenew($"Your database license will expire in less than three days, on the {expiry:dd MMM yyyy} at {expiry:hh:mm tt}.");
                _expiredLicenseCounter = 0;
            }
            else if (diffInDays < 7 && (_expiredLicenseCounter * LicenseCheckInterval).TotalHours >= 2)
            {
                LogRenew($"Your database license will expire in less than a week, on the {expiry:dd MMM yyyy} at {expiry:hh:mm tt}.");
                _expiredLicenseCounter = 0;
            }
            ++_expiredLicenseCounter;
        }

        public static void LogReadOnly()
        {
            LogError("Database is read-only because your license is invalid!");
        }

        private static void BeginReadOnly()
        {
            LogImportant("Your database is now in read-only mode, since your license is invalid; you will be unable to save any records to the database until you renew your license. For help with renewing your license, please see the documentation at https://prs-software.com.au/wiki/index.php/License_Renewal.");
            _readOnly = true;
        }

        private static void EndReadOnly()
        {
            LogImportant("Valid license found; the database is no longer read-only.");
            _readOnly = false;
        }

        private static void BeginLicenseCheckTimer()
        {
            // Unsubscribe first so calling Start more than once doesn't register duplicate handlers.
            LicenseTimer.Elapsed -= LicenseTimer_Elapsed;
            LicenseTimer.Elapsed += LicenseTimer_Elapsed;
            LicenseTimer.Start();
        }

        private static void LicenseTimer_Elapsed(object? sender, System.Timers.ElapsedEventArgs e)
        {
            // System.Timers.Timer silently swallows exceptions thrown by Elapsed handlers; log them instead.
            try
            {
                AssertLicense();
            }
            catch (Exception err)
            {
                LogError($"License check failed: {err.Message}");
            }
        }

        private static readonly Random LicenseIDGenerate = new Random();

        // Number of UserTracking IDs sampled (with replacement) into a renewed license.
        private const int LicenseTrackingSampleSize = 10;

        /// <summary>
        /// Re-samples UserTracking IDs created since the license's last renewal into
        /// <paramref name="licenseData"/>, re-encrypts it and saves it back to <paramref name="license"/>.
        /// </summary>
        private static void UpdateValidLicense(License license, LicenseData licenseData)
        {
            var ids = Provider.Query(
                new Filter<UserTracking>(x => x.Created).IsGreaterThanOrEqualTo(licenseData.LastRenewal),
                new Columns<UserTracking>(x => x.ID), log: false);

            var newIDList = new List<Guid>();
            if (ids.Rows.Count > 0)
            {
                for (int i = 0; i < LicenseTrackingSampleSize; i++)
                {
                    newIDList.Add(ids.Rows[LicenseIDGenerate.Next(0, ids.Rows.Count)].Get<UserTracking, Guid>(x => x.ID));
                }
            }
            licenseData.UserTrackingItems = newIDList.ToArray();

            if (LicenseUtils.TryEncryptLicense(licenseData, out var newData, out var error))
            {
                license.Data = newData;
                Provider.Save(license);
            }
            else
            {
                LogError($"Error updating license: {error}");
            }
        }

        // TODO: Switch to real system. While this is false, an invalid license is replaced with a freshly
        // generated one; when true, the switch in AssertLicense applies read-only enforcement instead.
        private static readonly bool LicenseEnforcementEnabled = false;

        /// <summary>
        /// Checks the stored license. If the database is currently read-only, only leaves read-only mode
        /// once the license becomes valid. Otherwise, with enforcement disabled, any invalid license is
        /// regenerated; with enforcement enabled, a valid license is refreshed and an invalid one puts the
        /// database into read-only mode.
        /// </summary>
        private static void AssertLicense()
        {
            var result = CheckLicenseValidity(out var license, out var licenseData);

            if (IsReadOnly)
            {
                if (result == LicenseValidation.Valid)
                {
                    EndReadOnly();
                }
                return;
            }

            if (!LicenseEnforcementEnabled)
            {
                if (result != LicenseValidation.Valid)
                {
                    RegenerateLicense(license);
                }
                return;
            }

            switch (result)
            {
                case LicenseValidation.Valid:
                    LogLicenseExpiry(licenseData!.Expiry);
                    UpdateValidLicense(license!, licenseData);
                    break;
                case LicenseValidation.Missing:
                    LogImportant("Database is unlicensed!");
                    BeginReadOnly();
                    break;
                case LicenseValidation.Expired:
                    LogImportant("Database license has expired!");
                    BeginReadOnly();
                    break;
                case LicenseValidation.Corrupt:
                    LogImportant("Database license is corrupt - you will need to renew your license.");
                    BeginReadOnly();
                    break;
                case LicenseValidation.Tampered:
                    LogImportant("Database license has been tampered with - you will need to renew your license.");
                    BeginReadOnly();
                    break;
            }
        }

        /// <summary>
        /// Replaces the stored license data with a newly generated license. Creates a new
        /// <see cref="License"/> record when none exists; logs an error if encryption fails.
        /// </summary>
        private static void RegenerateLicense(License? license)
        {
            var newLicense = LicenseUtils.GenerateNewLicense();
            if (LicenseUtils.TryEncryptLicense(newLicense, out var newData, out var error))
            {
                // A missing license is reported as null; previously this dereferenced it and threw.
                license ??= new License();
                license.Data = newData;
                Provider.Save(license);
            }
            else
            {
                LogError($"Error updating license: {error}");
            }
        }

        #endregion

        #region Logging

        private static void LogMessage(LogType type, string message)
        {
            Logger.Send(type, "", message);
        }

        private static void LogInfo(string message)
        {
            Logger.Send(LogType.Information, "", message);
        }

        private static void LogImportant(string message)
        {
            Logger.Send(LogType.Important, "", message);
        }

        private static void LogError(string message)
        {
            Logger.Send(LogType.Error, "", message);
        }

        #endregion

        /// <summary>Instantiates every registered store type, assigns the provider and calls <c>Init</c>.</summary>
        public static void InitStores()
        {
            foreach (var storetype in stores)
            {
                // SetStoreTypes only accepts types implementing IStore, so a direct cast is safe.
                var store = (IStore)Activator.CreateInstance(storetype)!;
                store.Provider = Provider;
                store.Init();
            }
        }

        /// <summary>
        /// Creates a store for <typeparamref name="TEntity"/>: the first registered subclass of
        /// <c>Store&lt;TEntity&gt;</c> if any, otherwise <c>Store&lt;TEntity&gt;</c> itself, populated
        /// with the provider and the caller's identity.
        /// </summary>
        public static IStore<TEntity> FindStore<TEntity>(Guid userguid, string userid, string platform, string version)
            where TEntity : Entity, new()
        {
            var defType = typeof(Store<>).MakeGenericType(typeof(TEntity));
            var subType = Stores.FirstOrDefault(myType => myType.IsSubclassOf(defType));

            var store = (Store<TEntity>)Activator.CreateInstance(subType ?? defType)!;
            store.Provider = Provider;
            store.UserGuid = userguid;
            store.UserID = userid;
            store.Platform = platform;
            store.Version = version;
            return store;
        }

        private static CoreTable DoQueryMultipleQuery<TEntity>(
            IQueryDef query,
            Guid userguid, string userid, string platform, string version)
            where TEntity : Entity, new()
        {
            var store = FindStore<TEntity>(userguid, userid, platform, version);
            return store.Query(query.Filter as Filter<TEntity>, query.Columns as Columns<TEntity>, query.SortOrder as SortOrder<TEntity>);
        }

        /// <summary>
        /// Runs each query in <paramref name="queries"/> in parallel against the matching entity store and
        /// returns the results under the same keys. Blocks until all queries complete.
        /// </summary>
        public static Dictionary<string, CoreTable> QueryMultiple(
            Dictionary<string, IQueryDef> queries,
            Guid userguid, string userid, string platform, string version)
        {
            var queryMethod = typeof(DbFactory).GetMethod(nameof(DoQueryMultipleQuery), BindingFlags.NonPublic | BindingFlags.Static)!;

            // Each task returns its own key/result pair and the dictionary is built only after all tasks
            // finish: Dictionary<,> is not safe for concurrent writes from several tasks.
            var tasks = queries
                .Select(item => Task.Run(() => new KeyValuePair<string, CoreTable>(
                    item.Key,
                    (queryMethod.MakeGenericMethod(item.Value.Type).Invoke(null, new object[]
                    {
                        item.Value,
                        userguid, userid, platform, version
                    }) as CoreTable)!)))
                .ToArray();

            Task.WaitAll(tasks);

            return tasks.ToDictionary(task => task.Result.Key, task => task.Result.Value);
        }

        #region Supported Types

        private class ModuleConfiguration : Dictionary<string, bool>, LocalConfigurationSettings
        {
        }

        private static Type[]? _dbtypes;

        /// <summary>Names of the enabled entity types, with '.' replaced by '_'.</summary>
        public static IEnumerable<string> SupportedTypes()
        {
            _dbtypes ??= LoadSupportedTypes();
            return _dbtypes.Select(x => x.EntityName().Replace(".", "_"));
        }

        /// <summary>
        /// Reads the per-entity enabled flags from the local module configuration next to the provider
        /// URL. Entities not yet listed are enabled and written back to the configuration.
        /// </summary>
        private static Type[] LoadSupportedTypes()
        {
            var result = new List<Type>();

            var path = Provider.URL.ToLower();
            var directory = Path.GetDirectoryName(path) ?? "";
            var fileName = Path.GetFileName(path);

            var config = new LocalConfiguration<ModuleConfiguration>(directory, fileName).Load();
            var bChanged = false;
            foreach (var type in Entities)
            {
                var key = type.EntityName();
                if (config.TryGetValue(key, out var enabled))
                {
                    if (enabled)
                        result.Add(type);
                    else
                        LogInfo($"Entity [{key}] is disabled");
                }
                else
                {
                    config[key] = true;
                    result.Add(type);
                    bChanged = true;
                }
            }

            if (bChanged)
                new LocalConfiguration<ModuleConfiguration>(directory, fileName).Save(config);

            return result.ToArray();
        }

        /// <summary>True if <typeparamref name="T"/> is one of the enabled entity types.</summary>
        public static bool IsSupported<T>() where T : Entity
        {
            _dbtypes ??= LoadSupportedTypes();
            return _dbtypes.Contains(typeof(T));
        }

        #endregion

        #region Private Methods

        /// <summary>
        /// Rebuilds <see cref="LoadedScripts"/> from all Before/After Query, Save and Delete scripts plus
        /// AfterLoad scripts. Scripts that fail to compile are logged and skipped.
        /// </summary>
        public static void LoadScripts()
        {
            LogInfo("Loading Script Cache...");
            LoadedScripts.Clear();
            var scripts = Provider.Load(
                new Filter<Script>
                        (x => x.ScriptType).IsEqualTo(ScriptType.BeforeQuery)
                    .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterQuery)
                    .Or(x => x.ScriptType).IsEqualTo(ScriptType.BeforeSave)
                    .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterSave)
                    .Or(x => x.ScriptType).IsEqualTo(ScriptType.BeforeDelete)
                    .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterDelete)
                    .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterLoad)
            );
            foreach (var script in scripts)
            {
                var key = $"{script.Section} {script.ScriptType}";
                var doc = new ScriptDocument(script.Code);
                if (doc.Compile())
                {
                    LogInfo($"- {script.Section}.{script.ScriptType} Compiled Successfully");
                    LoadedScripts[key] = doc;
                }
                else
                {
                    LogError($"- {script.Section}.{script.ScriptType} Compile Exception:\n{doc.Result}");
                }
            }
            LogInfo("Loading Script Cache Complete");
        }

        private static Type[] stores = { };

        /// <summary>
        /// Keeps only concrete, non-generic classes from <paramref name="types"/> and registers them as
        /// the store types.
        /// </summary>
        /// <exception cref="Exception">Thrown if a kept type does not implement <see cref="IStore"/>.</exception>
        private static void SetStoreTypes(Type[] types)
        {
            types = types.Where(
                myType => myType.IsClass
                    && !myType.IsAbstract
                    && !myType.IsGenericType).ToArray();
            foreach (var type in types)
            {
                if (!type.GetInterfaces().Contains(typeof(IStore)))
                    throw new Exception(string.Format("{0} is not a valid store", type.Name));
            }
            stores = types;
        }

        /// <summary>The end-user object model: concrete, non-generic <see cref="Entity"/> subclasses from <see cref="Entities"/>.</summary>
        private static Type[] ConsolidatedObjectModel()
        {
            return Entities.Where(x =>
                x.GetTypeInfo().IsClass
                && !x.GetTypeInfo().IsGenericType
                && x.GetTypeInfo().IsSubclassOf(typeof(Entity))
            ).ToArray();
        }

        private enum SchemaStatus
        {
            New,
            Changed,
            Validated
        }

        /// <summary>
        /// Builds the schema of the object model as a map from "TypeName.PropertyPath" to property type,
        /// using <see cref="CoreUtils.PropertyList"/> for each type.
        /// </summary>
        private static Dictionary<string, Type> GetSchema()
        {
            var model = new Dictionary<string, Type>();
            foreach (var type in ConsolidatedObjectModel())
            {
                Dictionary<string, Type> thismodel = CoreUtils.PropertyList(type, x => true, true);
                foreach (var key in thismodel.Keys)
                    model[type.Name + "." + key] = thismodel[key];
            }
            return model;
        }

        /// <summary>
        /// Compares the provider's stored schema with the current object model: <see cref="SchemaStatus.New"/>
        /// when nothing is stored, otherwise Validated/Changed depending on whether their serialised forms match.
        /// </summary>
        private static SchemaStatus ValidateSchema()
        {
            var db_schema = Provider.GetSchema();
            if (!db_schema.Any())
                return SchemaStatus.New;

            var mdl_json = Serialization.Serialize(GetSchema());
            var db_json = Serialization.Serialize(db_schema);
            return mdl_json.Equals(db_json) ? SchemaStatus.Validated : SchemaStatus.Changed;
        }

        private static void SaveSchema()
        {
            Provider.SaveSchema(GetSchema());
        }

        #endregion
    }
}