Browse Source

Added code to initialise the nulls in the DB

Kenric Nugteren 5 days ago
parent
commit
abef83358e
3 changed files with 273 additions and 140 deletions
  1. 4 4
      InABox.Database/DbFactory.cs
  2. 9 0
      InABox.Database/VersionNumber.cs
  3. 260 136
      inabox.database.sqlite/SQLiteProvider.cs

+ 4 - 4
InABox.Database/DbFactory.cs

@@ -109,9 +109,9 @@ public static class DbFactory
         LoadScripts();
     }
     
-    public static DatabaseVersion GetVersionSettings()
+    public static DatabaseVersion GetVersionSettings(IProviderFactory? factory = null)
     {
-        var provider = ProviderFactory.NewProvider(Logger.New());
+        var provider = (factory ?? ProviderFactory).NewProvider(Logger.New());
         var result = provider.Query(Filter<GlobalSettings>.Where(x => x.Section).IsEqualTo(nameof(DatabaseVersion)))
             .Rows.FirstOrDefault()?.ToObject<GlobalSettings>();
         if(result != null)
@@ -130,9 +130,9 @@ public static class DbFactory
         return dbVersion;
     }
 
-    public static VersionNumber GetDatabaseVersion()
+    public static VersionNumber GetDatabaseVersion(IProviderFactory? factory = null)
     {
-        var dbVersion = GetVersionSettings();
+        var dbVersion = GetVersionSettings(factory: factory);
         return VersionNumber.Parse(dbVersion.Version);
     }
 

+ 9 - 0
InABox.Database/VersionNumber.cs

@@ -154,6 +154,15 @@ public class VersionNumber : IComparable
 
     public static bool operator ==(VersionNumber a, VersionNumber b)
     {
+        if(a is null)
+        {
+            return b is null;
+        }
+        else if(b is null)
+        {
+            return false;
+        }
+
         if (a.IsDevelopmentVersion)
             return b.IsDevelopmentVersion;
         if (b.IsDevelopmentVersion)

+ 260 - 136
inabox.database.sqlite/SQLiteProvider.cs

@@ -371,11 +371,12 @@ public class SQLiteProviderFactory : IProviderFactory
         {
             Log(LogType.Information, $"Creating Table: {nameof(CustomProperty)}");
 
-            CreateTable(access, typeof(CustomProperty), true, []);
+            CreateTable(access, typeof(CustomProperty), true, [], false);
         }
         else
         {
-            CheckFields(access, typeof(CustomProperty), value.Item1, []);
+            CheckFields(access, typeof(CustomProperty), value.Item1, [],
+                initialiseNulls: false);
         }
 
         var customproperties = MainProvider.Load<CustomProperty>(); // Filter<CustomProperty>.Where(x => x.Class).IsEqualTo(type.EntityName()))
@@ -391,7 +392,8 @@ public class SQLiteProviderFactory : IProviderFactory
                 if (!metadata.ContainsKey(table))
                 {
                     Log(LogType.Information, "Creating Table: " + type.Name);
-                    CreateTable(access, type, true, customproperties);
+                    CreateTable(access, type, true, customproperties,
+                        initialiseNulls: false);
                 }
             }
         }
@@ -403,51 +405,14 @@ public class SQLiteProviderFactory : IProviderFactory
             if (type.GetCustomAttribute<AutoEntity>() == null)
             {
                 table = type.EntityName().Split('.').Last();
-                CheckFields(access, type, metadata[table].Item1, customproperties);
+                CheckFields(access, type, metadata[table].Item1, customproperties,
+                    initialiseNulls: false);
             }
         }
 
-        metadata = LoadMetaData();
-        
-        foreach (var type in ordered)
-        {
-            if (type.GetCustomAttribute<AutoEntity>() == null)
-            {
-                table = type.Name;
-                CheckTriggers(access, type, metadata[table].Item2);
-            }
-        }
-
-        metadata = LoadMetaData();
-
-        foreach (var type in ordered)
-        {
-            if (type.GetCustomAttribute<AutoEntity>() == null)
-            {
-
-                table = type.EntityName().Split('.').Last();
-                CheckIndexes(access, type, metadata[table].Item3);
-            }
-        }
-
-        metadata = LoadMetaData();
-        
-        foreach (var type in ordered)
-        {
-            if (type.GetCustomAttribute<AutoEntity>() != null)
-            {
-                table = type.EntityName().Split('.').Last();
-                if (!metadata.ContainsKey(table))
-                {
-                    Log(LogType.Information, "Creating Table: " + type.EntityName().Split('.').Last());
-                    CreateTable(access, type, true, customproperties);
-                }
-                else
-                {
-                    CheckFields(access, type, metadata[table].Item1, customproperties);
-                }
-            }
-        }
+        LoadTriggers(access, ordered);
+        LoadIndexes(access, ordered);
+        LoadViews(access, ordered, customproperties, initialiseNulls: false);
 
         if (bForceRebuild)
         {
@@ -485,6 +450,124 @@ public class SQLiteProviderFactory : IProviderFactory
             Triggers = new List<MetadataEntry>();
         }
     }
+
+    // Skipped starting '[' already
+    private void ParseField(string sql, ref int i, Dictionary<string, string> fields)
+    {
+        var j = i;
+        while(j < sql.Length)
+        {
+            if (sql[j] == ']')
+            {
+                break;
+            }
+            ++j;
+        }
+
+        var fieldName = sql[i..j];
+
+        ++j;
+        while(j < sql.Length && char.IsWhiteSpace(sql[j]))
+        {
+            ++j;
+        }
+
+        i = j;
+
+        while(i < sql.Length)
+        {
+            if (sql[i] == '\'')
+            {
+                while(i < sql.Length && sql[i] != '\'')
+                {
+                    ++i;
+                }
+                ++i;
+            }
+            else if (sql[i] == ',')
+            {
+                break;
+            }
+            else if (sql[i] == ')')
+            {
+                break;
+            }
+            else
+            {
+                ++i;
+            }
+        }
+
+        fields[fieldName] = sql[j..i].Trim();
+    }
+
+    private Dictionary<string, string> ParseTableMetadata(string table, string sql, bool isTable)
+    {
+        var tableFields = new Dictionary<string, string>();
+
+        if (isTable)
+        {
+            var i = sql.IndexOf('(') + 1;
+            if(i == 0)
+            {
+                throw new Exception($"Expected 'CREATE TABLE {table} ('");
+            }
+
+            while (i < sql.Length)
+            {
+                var c = sql[i];
+                if (Char.IsWhiteSpace(c))
+                {
+                    ++i;
+                }
+                else if(c == ')')
+                {
+                    break;
+                }
+                else if(c == '[')
+                {
+                    ++i;
+                    ParseField(sql, ref i, tableFields);
+                }
+                else
+                {
+                    ++i;
+                }
+            }
+        }
+        else
+        {
+            sql = sql.Replace("\"", "")
+                .Replace("DISTINCT ", "");
+            sql = sql.Split(new String[] { " AS SELECT " }, StringSplitOptions.TrimEntries).Last();
+            sql = sql.Split(new String[] { " FROM " }, StringSplitOptions.TrimEntries).First();
+
+            var fields = sql.Replace("\n\t", "").Replace("\t", " ").Replace("\"", "").Trim().Split(',');
+            foreach (var fld in fields)
+            {
+                var field = fld.Trim()
+                    .Replace("\t", " ")
+                    .Replace("\"", "")
+                    .Replace("[", "").Replace("]", "");
+
+                var parts = field.Split(" as ");
+                if(parts.Length == 1)
+                {
+                    tableFields[field] = "";
+                }
+                else if(parts.Length == 2)
+                {
+                    field = parts[1];
+                    if (parts[0] != "NULL")
+                    {
+                        tableFields[field] = "";
+                    }
+                }
+            }
+        }
+
+        return tableFields;
+    }
     
     private Dictionary<string, Tuple<Dictionary<string, string>, Dictionary<string, string>, Dictionary<string, string>>> LoadMetaData()
     {
@@ -505,77 +588,18 @@ public class SQLiteProviderFactory : IProviderFactory
                     if (reader.HasRows)
                         while (reader.Read())
                         {
-                            var tblinfo = new Tuple<Dictionary<string, string>, Dictionary<string, string>, Dictionary<string, string>>(
-                                new Dictionary<string, string>(), new Dictionary<string, string>(), new Dictionary<string, string>());
-
                             var table = reader.GetString(0);
                             var sql = reader.GetString(1);
-                            bool istable = String.Equals(reader.GetString(2),"table");
+                            var istable = String.Equals(reader.GetString(2),"table");
 
-                            if (istable)
+                            try
                             {
-                                sql = sql.Replace("\"", "")
-                                    .Replace(string.Format("CREATE TABLE {0} (", table), "");
-                                sql = sql.Remove(sql.Length - 1).Trim();
-                                var fields = sql.Replace("\n\t", "").Replace("\t", " ").Replace("\"", "").Trim().Split(',');
-                                var primarykey = "";
-                                foreach (var fld in fields)
-                                {
-                                    var field = fld.Trim().Replace("\t", " ").Replace("\"", "").Replace("[", "").Replace("]", "");
-                                    //if (field.ToUpper().StartsWith("CONSTRAINT"))
-                                    //    tblinfo.Item2.Add(field);
-                                    if (field.ToUpper().StartsWith("PRIMARY KEY"))
-                                    {
-                                        primarykey = field.Replace("PRIMARY KEY(", "").Replace(")", "");
-                                    }
-                                    else
-                                    {
-                                        var comps = field.Split(' ');
-
-                                        tblinfo.Item1[comps[0]] = string.Format("{0}{1}", comps[1],
-                                            field.Contains("PRIMARY KEY") ? " PRIMARY KEY" : "");
-                                    }
-                                }
-
-                                if (!string.IsNullOrEmpty(primarykey))
-                                {
-                                    var pkfld = tblinfo.Item1[primarykey];
-                                    if (!pkfld.ToUpper().Contains("PRIMARY KEY"))
-                                        tblinfo.Item1[primarykey] = string.Format("{0} PRIMARY KEY", pkfld.Trim());
-                                }
+                                metadata[table] = new(ParseTableMetadata(table, sql, istable), new(), new());
                             }
-                            else
+                            catch(Exception e)
                             {
-                                sql = sql.Replace("\"", "")
-                                    .Replace("DISTINCT ", "");
-                                sql = sql.Split(new String[] { " AS SELECT " }, StringSplitOptions.TrimEntries).Last();
-                                sql = sql.Split(new String[] { " FROM " }, StringSplitOptions.TrimEntries).First();
-
-                                var fields = sql.Replace("\n\t", "").Replace("\t", " ").Replace("\"", "").Trim().Split(',');
-                                foreach (var fld in fields)
-                                {
-                                    var field = fld.Trim()
-                                        .Replace("\t", " ")
-                                        .Replace("\"", "")
-                                        .Replace("[", "").Replace("]", "");
-
-                                    var parts = field.Split(" as ");
-                                    if(parts.Length == 1)
-                                    {
-                                        tblinfo.Item1[field] = "";
-                                    }
-                                    else if(parts.Length == 2)
-                                    {
-                                        field = parts[1];
-                                        if (parts[0] != "NULL")
-                                        {
-                                            tblinfo.Item1[field] = "";
-                                        }
-                                    }
-                                }
+                                Log(LogType.Error, $"Invalid table metadata for {table}: {CoreUtils.FormatException(e)}");
                             }
-
-                            metadata[table] = tblinfo;
                         }
 
                     reader.Close();
@@ -622,6 +646,55 @@ public class SQLiteProviderFactory : IProviderFactory
         return metadata;
     }
 
+    private void LoadTriggers(SQLiteWriteAccessor access, IEnumerable<Type> ordered)
+    {
+        var metadata = LoadMetaData();
+        
+        foreach (var type in ordered)
+        {
+            if (type.GetCustomAttribute<AutoEntity>() == null)
+            {
+                CheckTriggers(access, type, metadata[type.Name].Item2);
+            }
+        }
+    }
+
+    private void LoadIndexes(SQLiteWriteAccessor access, IEnumerable<Type> ordered)
+    {
+        var metadata = LoadMetaData();
+
+        foreach (var type in ordered)
+        {
+            if (type.GetCustomAttribute<AutoEntity>() == null)
+            {
+
+                CheckIndexes(access, type, metadata[type.Name].Item3);
+            }
+        }
+    }
+
+    private void LoadViews(SQLiteWriteAccessor access, IEnumerable<Type> ordered, CustomProperty[] customProperties,
+        bool initialiseNulls)
+    {
+        var metadata = LoadMetaData();
+        
+        foreach (var type in ordered)
+        {
+            if (type.GetCustomAttribute<AutoEntity>() != null)
+            {
+                if (!metadata.TryGetValue(type.Name, out var value))
+                {
+                    Log(LogType.Information, $"Creating Table: {type.Name}");
+                    CreateTable(access, type, true, customProperties, initialiseNulls: initialiseNulls);
+                }
+                else
+                {
+                    CheckFields(access, type, value.Item1, customProperties, initialiseNulls: initialiseNulls);
+                }
+            }
+        }
+    }
+
     private static void LoadType(Type type, List<Type> into)
     {
         if (into.Contains(type))
@@ -684,7 +757,7 @@ public class SQLiteProviderFactory : IProviderFactory
         return "TEXT";
     }
 
-    private void LoadFields(Type type, Dictionary<string, string> fields)
+    private void LoadFields(Type type, Dictionary<string, string> fields, bool initialiseNulls)
     {
         AutoEntity? view = type.GetCustomAttribute<AutoEntity>();
         Type definition = view?.Generator != null
@@ -693,9 +766,11 @@ public class SQLiteProviderFactory : IProviderFactory
 
         foreach(var property in DatabaseSchema.Properties(definition).Where(x => x.IsDBColumn))
         {
+            var defaultValue = SQLiteProvider.EscapeValue(SQLiteProvider.GetColumnDefaultValue(property.PropertyType));
             fields[property.Name] =
                 ColumnType(property.PropertyType)
-                + (property.Name.Equals("ID") ? " PRIMARY KEY" : "");
+                + (property.Name.Equals("ID") ? " PRIMARY KEY" : "")
+                + (initialiseNulls ? $" DEFAULT {defaultValue}" : "");
         }
     }
 
@@ -827,13 +902,13 @@ public class SQLiteProviderFactory : IProviderFactory
                     Log(LogType.Information, "Creating Table: " + type.EntityName().Split('.').Last());
                     using (var access = MainProvider.GetWriteAccess())
                     {
-                        CreateTable(access, type, true, customproperties);
+                        CreateTable(access, type, true, customproperties, initialiseNulls: false);
                     }
                 }
                 else
                 {
                     var type_fields = new Dictionary<string, string>();
-                    LoadFields(view.Generator.Definition, type_fields);
+                    LoadFields(view.Generator.Definition, type_fields, initialiseNulls: false);
 
                     using (var access = MainProvider.GetWriteAccess())
                     {
@@ -844,10 +919,10 @@ public class SQLiteProviderFactory : IProviderFactory
         }
     }
 
-    private Dictionary<string, object?> CheckDefaultColumns(IAutoEntityGenerator generator)
+    private Dictionary<string, object?> CheckDefaultColumns(IAutoEntityGenerator generator, bool initialiseNulls)
     {
         var viewfields = new Dictionary<string, string>();
-        LoadFields(generator.Definition, viewfields);
+        LoadFields(generator.Definition, viewfields, initialiseNulls: initialiseNulls);
         Dictionary<String, object?> result = new Dictionary<string, object?>();
         if (!viewfields.ContainsKey("ID"))
             result["ID"] = null;
@@ -862,7 +937,7 @@ public class SQLiteProviderFactory : IProviderFactory
         return result;
     }
     
-    private void CreateTable(SQLiteWriteAccessor access, Type type, bool includeconstraints, CustomProperty[] customproperties)
+    private void CreateTable(SQLiteWriteAccessor access, Type type, bool includeconstraints, CustomProperty[] customproperties, bool initialiseNulls)
     {
         var tablename = type.EntityName().Split('.').Last();
         var ddl = new List<string>();
@@ -891,13 +966,13 @@ public class SQLiteProviderFactory : IProviderFactory
                 {
 
                     var columns = new List<IBaseColumn>();
-                    var constants = CheckDefaultColumns(union);
+                    var constants = CheckDefaultColumns(union, initialiseNulls: initialiseNulls);
                     
                     var interfacefields = new Dictionary<string, string>();
-                    LoadFields(union.Definition, interfacefields);
+                    LoadFields(union.Definition, interfacefields, initialiseNulls: initialiseNulls);
                     
                     var entityfields = new Dictionary<string, string>();
-                    LoadFields(table.Entity, entityfields);
+                    LoadFields(table.Entity, entityfields, initialiseNulls: initialiseNulls);
 
                     foreach (var field in interfacefields.Keys)
                     {
@@ -951,7 +1026,7 @@ public class SQLiteProviderFactory : IProviderFactory
             else if ( view.Generator is IAutoEntityCrossGenerator cross)
             {
                 List<String> constants = new List<string>();
-                foreach (var constant in CheckDefaultColumns(cross))
+                foreach (var constant in CheckDefaultColumns(cross, initialiseNulls: initialiseNulls))
                     constants.Add($"{SQLiteProvider.EscapeValue(constant.Value)} as [{constant.Key}]");
                 
                 String query = String.Format(
@@ -994,7 +1069,7 @@ public class SQLiteProviderFactory : IProviderFactory
                 foreach (var constant in cartesian.Constants)
                     fields.Add($"{SQLiteProvider.EscapeValue(constant.Constant)} as [{constant.Mapping.Property}]");
                 
-                foreach (var constant in CheckDefaultColumns(cartesian))
+                foreach (var constant in CheckDefaultColumns(cartesian, initialiseNulls: initialiseNulls))
                     fields.Add($"{SQLiteProvider.EscapeValue(constant.Value)} as [{constant.Key}]");
 
                 StringBuilder sb = new StringBuilder();
@@ -1078,7 +1153,7 @@ public class SQLiteProviderFactory : IProviderFactory
             var fields = new Dictionary<string, string>();
             var constraints = new List<string>();
             var indexes = new List<string>();
-            LoadFields(type, fields);
+            LoadFields(type, fields, initialiseNulls: initialiseNulls);
             var defs = new List<string>();
             foreach (var key in fields.Keys)
                 defs.Add(string.Format("[{0}] {1}", key, fields[key]));
@@ -1102,7 +1177,8 @@ public class SQLiteProviderFactory : IProviderFactory
 
     private void RebuildTable(SQLiteWriteAccessor access, Type type, Dictionary<string, string> table_fields,
         Dictionary<string, string> type_fields,
-        CustomProperty[] customproperties)
+        CustomProperty[] customproperties,
+        bool initialiseNulls = false)
     {
 
         var table = type.EntityName().Split('.').Last();
@@ -1137,7 +1213,7 @@ public class SQLiteProviderFactory : IProviderFactory
                 if (!String.IsNullOrWhiteSpace(drop))
                     MainProvider.ExecuteSQL(access, string.Format("DROP {0} {1};", drop, table));
 
-                CreateTable(access, type, true, customproperties);
+                CreateTable(access, type, true, customproperties, initialiseNulls: initialiseNulls);
 
             }
             catch (Exception e)
@@ -1229,18 +1305,36 @@ public class SQLiteProviderFactory : IProviderFactory
                     if (existingtable)
                         MainProvider.ExecuteSQL(access, string.Format("ALTER TABLE {0} RENAME TO _{0}_old;", table));
 
-                    CreateTable(access, type, true, customproperties);
-
-                    var fields = new List<string>();
+                    CreateTable(access, type, true, customproperties, initialiseNulls: initialiseNulls);
+                    var fields = new List<Tuple<string, string>>();
                     foreach (var field in type_fields.Keys)
                         if (table_fields.ContainsKey(field))
-                            fields.Add("[" + field + "]");
+                        {
+                            if (initialiseNulls)
+                            {
+                                if (field == nameof(Entity.ID))
+                                {
+                                    fields.Add(new($"[{field}]", $"[{field}]"));
+                                }
+                                else
+                                {
+                                    var defValue = SQLiteProvider.EscapeValue(SQLiteProvider.GetColumnDefaultValue(DatabaseSchema.PropertyStrict(type, field).PropertyType));
+                                    fields.Add(new($"[{field}]", $"IFNULL([{field}],{defValue}) as [{field}]"));
+                                }
+                            }
+                            else
+                            {
+                                fields.Add(new($"[{field}]", $"[{field}]"));
+                            }
+                        }
 
                     if (existingtable)
                     {
                         MainProvider.ExecuteSQL(access,
-                            string.Format("INSERT INTO {0} ({1}) SELECT {1} FROM _{0}_old;", table,
-                                string.Join(", ", fields)));
+                            string.Format("INSERT INTO {0} ({1}) SELECT {2} FROM _{0}_old;",
+                                table,
+                                string.Join(", ", fields.Select(x => x.Item1)),
+                                string.Join(", ", fields.Select(x => x.Item2))));
                         MainProvider.ExecuteSQL(access, string.Format("DROP TABLE _{0}_old;", table));
                     }
 
@@ -1258,11 +1352,12 @@ public class SQLiteProviderFactory : IProviderFactory
         
     }
 
-    private void CheckFields(SQLiteWriteAccessor access, Type type, Dictionary<string, string> current_fields, CustomProperty[] customproperties)
+    private void CheckFields(SQLiteWriteAccessor access, Type type, Dictionary<string, string> current_fields, CustomProperty[] customproperties,
+        bool initialiseNulls)
     {
         var type_fields = new Dictionary<string, string>();
         var view = type.GetCustomAttribute<AutoEntity>();
-        LoadFields(type, type_fields);
+        LoadFields(type, type_fields, initialiseNulls: initialiseNulls);
 
         var bRebuild = false;
         foreach (var field in type_fields.Keys)
@@ -1339,6 +1434,35 @@ public class SQLiteProviderFactory : IProviderFactory
 
     #endregion
 
+    public void InitializeNulls()
+    {
+        using var access = MainProvider.GetWriteAccess();
+
+        var ordered = new List<Type>();
+        foreach (var type in Types)
+            LoadType(type, ordered);
+
+        var metadata = LoadMetaData();
+        var customproperties = MainProvider.Load<CustomProperty>();
+        
+        foreach (var type in ordered)
+        {
+            if (type.GetCustomAttribute<AutoEntity>() == null)
+            {
+                var type_fields = new Dictionary<string, string>();
+                LoadFields(type, type_fields,
+                    initialiseNulls: true);
+
+                RebuildTable(access, type, metadata[type.Name].Item1, type_fields, customproperties,
+                    initialiseNulls: true);
+            }
+        }
+
+        LoadTriggers(access, ordered);
+        LoadIndexes(access, ordered);
+        LoadViews(access, ordered, customproperties, true);
+    }
+
     private void Log(LogType type, string message)
     {
         Logger.Send(type, "", message);
@@ -1954,7 +2078,7 @@ public class SQLiteProvider : IProvider
         return sParam;
     }
 
-    private static object? GetFilterDefaultValue(Type type)
+    internal static object? GetColumnDefaultValue(Type type)
     {
         if(type == typeof(string))
         {
@@ -2105,7 +2229,7 @@ public class SQLiteProvider : IProvider
                     }
                     else
                     {
-                        var strProp = $"IFNULL({prop},{EscapeValue(GetFilterDefaultValue(filter.Type))})";
+                        var strProp = $"IFNULL({prop},{EscapeValue(GetColumnDefaultValue(filter.Type))})";
                         value = useparams ? EncodeParameter(command, value) : EscapeValue(filter.Value);
                         result = string.Format("(" + operators[filter.Operator] + ")", strProp, value);
                     }
@@ -2113,9 +2237,9 @@ public class SQLiteProvider : IProvider
             }
             else
             {
-                var strProp = $"IFNULL({prop},{EscapeValue(GetFilterDefaultValue(filter.Type))})";
+                var strProp = $"IFNULL({prop},{EscapeValue(GetColumnDefaultValue(filter.Type))})";
                 var strValue = filter.Value is FilterConstant constant
-                    ? constant == FilterConstant.Null ? EscapeValue(GetFilterDefaultValue(filter.Type)) : GetFilterConstant(constant)
+                    ? constant == FilterConstant.Null ? EscapeValue(GetColumnDefaultValue(filter.Type)) : GetFilterConstant(constant)
                     : useparams ? EncodeParameter(command, Encode(filter.Value, filter.Type, convertToNull: false)) : EscapeValue(filter.Value);
 
                 result = string.Format($"({operators[filter.Operator]})", strProp, strValue);