module shark.sql;

import std.conv : to;
static import std.datetime;
import std.exception : enforce;
import std.experimental.logger : warning;
import std.string : join;

import shark.clause;
import shark.database;
import shark.entity;

import xbuffer : Buffer;

// debug
import std.stdio;

/**
 * Generic SQL database. It is possible to execute queries and
 * select queries.
 * See specific implementations for more complex operations.
 */
abstract class SqlDatabase : Database {

	/**
	 * Runs a query without receiving anything back.
	 * Note that running just this method may break some implementations.
	 * Example:
	 * ---
	 * database.query("drop table test;");
	 * ---
	 */
	public abstract void query(string);

	/**
	 * Runs a select query and returns the result. This method
	 * does not break the flow of the protocol like `query` does.
	 * This method is intended for usage with complex queries.
	 * Example:
	 * ---
	 * auto result = database.querySelect("select * from test order by rand() limit 1");
	 * result.bind!Test();
	 * ---
	 */
	public abstract Result querySelect(string);

	// CREATE | ALTER

	/**
	 * Creates the table described by `initInfo` when it does not exist,
	 * otherwise reconciles the existing columns with the entity's fields:
	 * changed columns are altered, new fields are added and columns that
	 * no longer have a matching field are dropped.
	 */
	protected override void initImpl(InitInfo initInfo) {
		TableInfo[string] tableInfo = getTableInfo(initInfo.tableName);
		if(tableInfo is null) {
			// the table does not exist yet: create it with every field
			// plus the (optional, possibly composite) primary key
			string[] fields;
			foreach(field ; initInfo.fields) {
				fields ~= generateField(field);
			}
			if(initInfo.primaryKeys.length) {
				fields ~= "primary key(" ~ initInfo.primaryKeys.join(",") ~ ")";
			}
			createTable(initInfo.tableName, fields);
		} else {
			// the table exists: compare every field with the existing column
			foreach(field ; initInfo.fields) {
				auto ptr = field.name in tableInfo;
				if(ptr) {
					// types are compared as bit flags: the type is considered
					// changed only when the two values share no bit
					immutable typeChanged = (field.type & ptr.type) == 0;
					immutable nullableChanged = field.nullable != ptr.nullable;
					if(typeChanged || nullableChanged) {
						alterTableColumn(initInfo.tableName, field, typeChanged, nullableChanged);
					}
				} else {
					// field added
					alterTableAddColumn(initInfo.tableName, field);
				}
				// whatever is left in tableInfo after this loop has no matching field
				tableInfo.remove(field.name);
			}
			foreach(name ; tableInfo.byKey) {
				// field removed, just drop it
				alterTableDropColumn(initInfo.tableName, name);
			}
		}
	}

	/**
	 * Returns: the table's columns keyed by column name, or null if the
	 * table doesn't exist.
	 */
	protected abstract TableInfo[string] getTableInfo(string table);

	/**
	 * Description of an existing column, as returned by `getTableInfo`.
	 */
	protected static struct TableInfo {

		/// Column name.
		string name;

		/// Column type; compared with `InitInfo.Field.type` using a bitwise and
		/// in `initImpl` (presumably a set of type flags — confirm in shark.database).
		uint type;

		/// Column length.
		size_t length;

		/// Whether the column accepts null values.
		bool nullable;

		/// Column default value, or null when none.
		string defaultValue = null;

	}

	/**
	 * Returns: the column definition for `field`, as used in `create table`
	 * and `alter table ... add`.
	 */
	protected abstract string generateField(InitInfo.Field field);

	/**
	 * Runs `create table <table> (<fields joined by ",">);`.
	 */
	protected void createTable(string table, string[] fields) {
		query("create table " ~ table ~ " (" ~ fields.join(",") ~ ");");
	}

	/**
	 * Alters an existing column whose type and/or nullability changed.
	 */
	protected abstract void alterTableColumn(string table, InitInfo.Field field, bool typeChanged, bool nullableChanged);

	/**
	 * Runs `alter table <table> add <column definition>;`.
	 */
	protected void alterTableAddColumn(string table, InitInfo.Field field) {
		query("alter table " ~ table ~ " add " ~ generateField(field) ~ ";");
	}

	/**
	 * Runs `alter table <table> drop <column>;`.
	 */
	protected void alterTableDropColumn(string table, string column) {
		query("alter table " ~ table ~ " drop " ~ column ~ ";");
	}

	// SELECT

	/**
	 * Builds and runs a select query from the selected fields (or `*`), the
	 * where clause, the ordering (random or by fields) and the limit.
	 */
	protected override Result selectImpl(SelectInfo selectInfo, Select select) {
		string where;
		string[] order;
		if(select.where.statement !is null) {
			where = stringifyStatements(select.where.statement);
		}
		if(select.order.rand) {
			order ~= randomFunction;
		} else {
			foreach(field ; select.order.fields) {
				order ~= field.name ~ " " ~ (field._asc ? "asc" : "desc");
			}
		}
		string q = "select " ~ (selectInfo.fields.length ? selectInfo.fields.join(",") : "*") ~ " from " ~ selectInfo.tableName;
		if(where.length) q ~= " where " ~ where;
		if(order.length) q ~= " order by " ~ order.join(",");
		// an upper limit of 0 means "no limit"; a non-zero lower bound emits
		// `limit lower,upper`
		// NOTE(review): dialects that accept `limit a,b` read it as offset,count —
		// confirm Select.limit stores values with that meaning
		if(select.limit.upper != 0) {
			if(select.limit.lower == 0) q ~= " limit " ~ select.limit.upper.to!string;
			else q ~= " limit " ~ select.limit.lower.to!string ~ "," ~ select.limit.upper.to!string;
		}
		return querySelect(q ~ ";");
	}

	// INSERT

	/**
	 * Splits the insert info into parallel name/value arrays and delegates
	 * to `insertInto`.
	 */
	protected override Result insertImpl(InsertInfo insertInfo) {
		string[] names;
		string[] values;
		foreach(field ; insertInfo.fields) {
			names ~= field.name;
			values ~= field.value;
		}
		return insertInto(insertInfo.tableName, names, values, insertInfo.primaryKeys);
	}

	/**
	 * Implementation-specific insert; `names` and `fields` are parallel arrays.
	 */
	protected abstract Result insertInto(string table, string[] names, string[] fields, string[] primaryKeys);

	// UPDATE

	/**
	 * Runs `update <table> set name=value,... [where ...];`.
	 * Logs a warning when there is no where statement, since the whole
	 * table is updated in that case.
	 */
	protected override void updateImpl(UpdateInfo updateInfo, Clause.Where where) {
		string[] sets;
		foreach(field ; updateInfo.fields) {
			sets ~= field.name ~ "=" ~ field.value;
		}
		string q = "update " ~ updateInfo.tableName ~ " set " ~ sets.join(",");
		if(where.statement !is null) q ~= " where " ~ stringifyStatements(where.statement);
		else warning("Where statement is empty! Updating the whole table!");
		query(q ~ ";");
	}

	// DELETE

	/**
	 * Runs `delete from <table> [where ...];`.
	 * Logs a warning when there is no where statement, since every row
	 * is deleted in that case.
	 */
	protected override void deleteImpl(string table, Clause.Where where) {
		string q = "delete from " ~ table;
		if(where.statement !is null) q ~= " where " ~ stringifyStatements(where.statement);
		else warning("Where statement is empty! Deleting the whole table!");
		query(q ~ ";");
	}

	// DROP

	/// Runs `drop table if exists <table>;`.
	public override void dropIfExists(string table) {
		query("drop table if exists " ~ table ~ ";");
	}

	/// Runs `drop table <table>;`.
	public override void drop(string table) {
		query("drop table " ~ table ~ ";");
	}

	// UTILS

	/**
	 * Converts a where statement tree into SQL. Complex statements are
	 * rendered recursively as `(left) glue (right)`; simple statements as
	 * `field operator value`, escaping the value when required.
	 * Throws: Exception if the statement is neither a complex nor a simple statement.
	 */
	protected string stringifyStatements(Clause.Where.GenericStatement statement) {
		if(auto complex = cast(Clause.Where.ComplexStatement)statement) {
			return "(" ~ stringifyStatements(complex.leftStatement) ~ ") " ~ glueToString(complex.glue) ~ " (" ~ stringifyStatements(complex.rightStatement) ~ ")";
		}
		auto simple = cast(Clause.Where.Statement)statement;
		// a plain `assert` is compiled out with -release, which would turn an
		// unsupported statement type into a null dereference below
		enforce(simple !is null, "Unsupported where statement type");
		if(simple.needsEscaping) return simple.field ~ " " ~ operatorToString(simple.operator) ~ " " ~ escape(simple.value);
		else return simple.field ~ " " ~ operatorToString(simple.operator) ~ " " ~ simple.value;
	}

	/// Returns: the SQL token for a where operator.
	protected string operatorToString(Clause.Where.Operator operator) {
		final switch(operator) with(Clause.Where.Operator) {
			case isNull: return "is";
			case equals: return "=";
			case notEquals: return "!=";
			case greaterThan: return ">";
			case greaterThanOrEquals: return ">=";
			case lessThan: return "<";
			case lessThanOrEquals: return "<=";
		}
	}

	/// Returns: the SQL keyword joining two where statements.
	protected string glueToString(Clause.Where.Glue glue) {
		final switch(glue) with(Clause.Where.Glue) {
			case or: return "or";
			case and: return "and";
		}
	}

	/// SQL expression used in `order by` for random ordering.
	protected abstract @property string randomFunction();

	/**
	 * Quotes a string literal with single quotes, doubling any embedded
	 * single quote.
	 * NOTE(review): backslashes are not escaped; dialects that treat `\` as an
	 * escape character inside literals may need to override this.
	 */
	protected override string escapeString(string value) {
		import std.string : replace;
		return "'" ~ value.replace("'", "''") ~ "'";
	}

	/// Quotes a date as an ISO extended string (YYYY-MM-DD).
	protected override string escapeDate(std.datetime.Date value) {
		return "'" ~ value.toISOExtString() ~ "'";
	}

	/// Quotes a date-time as an ISO extended string (YYYY-MM-DDTHH:MM:SS).
	protected override string escapeDateTime(std.datetime.DateTime value) {
		return "'" ~ value.toISOExtString() ~ "'";
	}

	/// Quotes a time of day as an ISO extended string (HH:MM:SS).
	protected override string escapeTime(std.datetime.TimeOfDay value) {
		return "'" ~ value.toISOExtString() ~ "'";
	}

	/**
	 * Utilities for prepared statements.
	 */
	public static struct Prepared {

		/// A typed parameter of a prepared statement.
		static interface Param {

			public @property Type type();

		}

		/// Parameter holding a value of type `T`, tagged with `_type`.
		static class ParamImpl(T, Type _type) : Param {

			public T value;

			public override Type type() {
				return _type;
			}

			public this(T value) {
				this.value = value;
			}

			override string toString() {
				return value.to!string;
			}

			alias value this;

		}

		/**
		 * Wraps every argument in a `ParamImpl` tagged with its `Type`.
		 * Unsigned integers are stored in the signed type of the same width,
		 * so values above the signed maximum wrap around.
		 * Unsupported argument types are rejected at compile time.
		 */
		static Param[] prepare(E...)(E params) {
			Param[] ret;
			foreach(param ; params) {
				alias T = typeof(param);
				static if(is(T == Bool) || is(T == bool)) ret ~= new ParamImpl!(bool, Type.BOOL)(param);
				else static if(is(T == Byte) || is(T == byte) || is(T == ubyte)) ret ~= new ParamImpl!(byte, Type.BYTE)(param);
				else static if(is(T == Short) || is(T == short) || is(T == ushort)) ret ~= new ParamImpl!(short, Type.SHORT)(param);
				else static if(is(T == Integer) || is(T == int) || is(T == uint)) ret ~= new ParamImpl!(int, Type.INT)(param);
				else static if(is(T == Long) || is(T == long) || is(T == ulong)) ret ~= new ParamImpl!(long, Type.LONG)(param);
				// ...
				else static if(is(T == String) || is(T == string)) ret ~= new ParamImpl!(string, Type.STRING)(param);
				else static assert(0, "Type " ~ T.stringof ~ " not supported");
			}
			return ret;
		}

	}

}