1 module shark.sql;
2 
3 import std.conv : to;
4 static import std.datetime;
5 import std.exception : enforce;
6 import std.experimental.logger : warning;
7 import std.string : join;
8 
9 import shark.clause;
10 import shark.database;
11 import shark.entity;
12 
13 import xbuffer : Buffer;
14 
15 // debug
16 import std.stdio;
17 
18 /**
19  * Generic SQL database. It is possible to execute queries and
20  * select queries.
21  * See specific implementations for more complex operations.
22  */
abstract class SqlDatabase : Database {

	/**
	 * Runs a query without receiving anything back.
	 * Note that running just this method may break some implementations.
	 * Example:
	 * ---
	 * database.query("drop table test;");
	 * ---
	 */
	public abstract void query(string);

	/**
	 * Runs a select query and returns the result. This method
	 * does not break the flow of the protocol like `query` does.
	 * This method is intended for usage with complex queries.
	 * Example:
	 * ---
	 * auto result = database.querySelect("select * from test order by rand() limit 1");
	 * result.bind!Test();
	 * ---
	 */
	public abstract Result querySelect(string);

	// CREATE | ALTER

	/**
	 * Makes the table described by `initInfo` match its field list.
	 *
	 * When `getTableInfo` returns null the table is created with every field
	 * and, if any, a `primary key(...)` constraint. Otherwise the existing
	 * table is altered column by column:
	 * - columns missing from the table are added;
	 * - columns whose type or nullability differ are altered;
	 * - columns present in the table but not in `initInfo.fields` are dropped.
	 */
	protected override void initImpl(InitInfo initInfo) {
		TableInfo[string] tableInfo = getTableInfo(initInfo.tableName);
		if(tableInfo is null) {
			// the table does not exist yet: create it from scratch
			string[] fields;
			foreach(field ; initInfo.fields) {
				fields ~= generateField(field);
			}
			if(initInfo.primaryKeys.length) {
				fields ~= "primary key(" ~ initInfo.primaryKeys.join(",") ~ ")";
			}
			createTable(initInfo.tableName, fields);
		} else {
			// the table exists: reconcile each declared field with the current columns
			foreach(field ; initInfo.fields) {
				auto ptr = field.name in tableInfo;
				if(ptr) {
					// the type counts as changed when `field.type & ptr.type` is zero (no bits in common)
					immutable typeChanged = (field.type & ptr.type) == 0;
					immutable nullableChanged = field.nullable != ptr.nullable;
					if(typeChanged || nullableChanged) {
						alterTableColumn(initInfo.tableName, field, typeChanged, nullableChanged);
					}
				} else {
					// field added to the entity
					alterTableAddColumn(initInfo.tableName, field);
				}
				// mark the column as handled so only stale columns remain in `tableInfo`
				tableInfo.remove(field.name);
			}
			// every remaining column is no longer declared: drop it (its data is lost)
			foreach(name ; tableInfo.byKey) {
				alterTableDropColumn(initInfo.tableName, name);
			}
		}
	}

	/**
	 * Returns: table info or null if the table doesn't exist.
	 */
	protected abstract TableInfo[string] getTableInfo(string table);

	/**
	 * Description of an existing column, as returned by `getTableInfo`.
	 * `initImpl` compares only `type` and `nullable` against the declared field.
	 */
	protected static struct TableInfo {

		string name;

		uint type;

		size_t length;

		bool nullable;

		string defaultValue = null;

	}

	/**
	 * Returns the column definition (as used inside `create table` and
	 * `alter table ... add`) for the given field.
	 */
	protected abstract string generateField(InitInfo.Field field);

	/**
	 * Runs `create table <table> (<fields joined by ",">);`.
	 */
	protected void createTable(string table, string[] fields) {
		query("create table " ~ table ~ " (" ~ fields.join(",") ~ ");");
	}

	/**
	 * Alters an existing column so that it matches `field`.
	 * `typeChanged` and `nullableChanged` tell which property differs.
	 */
	protected abstract void alterTableColumn(string table, InitInfo.Field field, bool typeChanged, bool nullableChanged);

	/**
	 * Runs `alter table <table> add <field definition>;`.
	 */
	protected void alterTableAddColumn(string table, InitInfo.Field field) {
		query("alter table " ~ table ~ " add " ~ generateField(field) ~ ";");
	}

	/**
	 * Runs `alter table <table> drop <column>;`.
	 */
	protected void alterTableDropColumn(string table, string column) {
		query("alter table " ~ table ~ " drop " ~ column ~ ";");
	}

	// SELECT

	/**
	 * Builds a `select` query from `selectInfo` and the where/order/limit
	 * clauses of `select`, runs it with `querySelect` and returns the result.
	 *
	 * The limit is emitted as `limit <upper>` followed by `offset <lower>`
	 * when `lower` is not 0. That form is accepted by MySQL, SQLite and
	 * PostgreSQL, and is equivalent to the MySQL/SQLite-only shorthand
	 * `limit <lower>,<upper>`. No limit is emitted when `upper` is 0.
	 */
	protected override Result selectImpl(SelectInfo selectInfo, Select select) {
		string where;
		if(select.where.statement !is null) {
			where = stringifyStatements(select.where.statement);
		}

		string[] order;
		if(select.order.rand) {
			order ~= randomFunction;
		} else {
			foreach(field ; select.order.fields) {
				order ~= field.name ~ " " ~ (field._asc ? "asc" : "desc");
			}
		}

		string q = "select " ~ (selectInfo.fields.length ? selectInfo.fields.join(",") : "*") ~ " from " ~ selectInfo.tableName;
		if(where.length) q ~= " where " ~ where;
		if(order.length) q ~= " order by " ~ order.join(",");
		if(select.limit.upper != 0) {
			q ~= " limit " ~ select.limit.upper.to!string;
			// portable "limit <count> offset <start>" instead of "limit <start>,<count>",
			// which PostgreSQL rejects
			if(select.limit.lower != 0) q ~= " offset " ~ select.limit.lower.to!string;
		}
		return querySelect(q ~ ";");
	}

	// INSERT

	/**
	 * Splits `insertInfo.fields` into parallel name/value arrays and forwards
	 * them to `insertInto`, returning its result.
	 */
	protected override Result insertImpl(InsertInfo insertInfo) {
		string[] names;
		string[] values;
		foreach(field ; insertInfo.fields) {
			names ~= field.name;
			values ~= field.value;
		}
		return insertInto(insertInfo.tableName, names, values, insertInfo.primaryKeys);
	}

	/**
	 * Performs the actual insert. `names[i]` is the column receiving `fields[i]`.
	 */
	protected abstract Result insertInto(string table, string[] names, string[] fields, string[] primaryKeys);

	// UPDATE

	/**
	 * Runs `update <table> set name=value,... [where ...];`.
	 * Each `field.value` is concatenated as-is into the query.
	 * Without a where statement a warning is logged and the whole table is updated.
	 */
	protected override void updateImpl(UpdateInfo updateInfo, Clause.Where where) {
		string[] sets;
		foreach(field ; updateInfo.fields) {
			sets ~= field.name ~ "=" ~ field.value;
		}
		string q = "update " ~ updateInfo.tableName ~ " set " ~ sets.join(",");
		if(where.statement !is null) q ~= " where " ~ stringifyStatements(where.statement);
		else warning("Where statement is empty! Updating the whole table!");
		query(q ~ ";");
	}

	// DELETE

	/**
	 * Runs `delete from <table> [where ...];`.
	 * Without a where statement a warning is logged and every row is deleted.
	 */
	protected override void deleteImpl(string table, Clause.Where where) {
		string q = "delete from " ~ table;
		if(where.statement !is null) q ~= " where " ~ stringifyStatements(where.statement);
		else warning("Where statement is empty! Deleting the whole table!");
		query(q ~ ";");
	}

	// DROP

	/**
	 * Runs `drop table if exists <table>;`.
	 */
	public override void dropIfExists(string table) {
		query("drop table if exists " ~ table ~ ";");
	}

	/**
	 * Runs `drop table <table>;`.
	 */
	public override void drop(string table) {
		query("drop table " ~ table ~ ";");
	}

	// UTILS

	/**
	 * Converts a where statement tree into SQL.
	 * Complex statements become `(<left>) <glue> (<right>)`, recursively.
	 * Simple statements become `<field> <operator> <value>`; the value is
	 * passed through `escape` only when the statement says it needs escaping.
	 */
	protected string stringifyStatements(Clause.Where.GenericStatement statement) {
		auto complex = cast(Clause.Where.ComplexStatement)statement;
		if(complex) {
			return "(" ~ stringifyStatements(complex.leftStatement) ~ ") " ~ glueToString(complex.glue) ~ " (" ~ stringifyStatements(complex.rightStatement) ~ ")";
		} else {
			auto simple = cast(Clause.Where.Statement)statement;
			// any statement that is not complex is expected to be a simple Statement
			assert(simple !is null);
			if(simple.needsEscaping) return simple.field ~ " " ~ operatorToString(simple.operator) ~ " " ~ escape(simple.value);
			else return simple.field ~ " " ~ operatorToString(simple.operator) ~ " " ~ simple.value;
		}
	}

	/**
	 * Returns the SQL token for a where operator (`isNull` maps to `is`).
	 */
	protected string operatorToString(Clause.Where.Operator operator) {
		final switch(operator) with(Clause.Where.Operator) {
			case isNull: return "is";
			case equals: return "=";
			case notEquals: return "!=";
			case greaterThan: return ">";
			case greaterThanOrEquals: return ">=";
			case lessThan: return "<";
			case lessThanOrEquals: return "<=";
		}
	}

	/**
	 * Returns the SQL keyword (`or` / `and`) joining two where statements.
	 */
	protected string glueToString(Clause.Where.Glue glue) {
		final switch(glue) with(Clause.Where.Glue) {
			case or: return "or";
			case and: return "and";
		}
	}

	/**
	 * SQL expression placed in the `order by` clause when random ordering
	 * is requested (see `selectImpl`).
	 */
	protected abstract @property string randomFunction();

	/**
	 * Quotes `value` as an SQL string literal, doubling embedded single quotes.
	 * NOTE(review): backslashes are left untouched; confirm this is safe for
	 * backends that treat `\` as an escape character inside string literals.
	 */
	protected override string escapeString(string value) {
		import std.string : replace;
		return "'" ~ value.replace("'", "''") ~ "'";
	}

	/**
	 * Quotes a date as `'YYYY-MM-DD'` (ISO extended format).
	 */
	protected override string escapeDate(std.datetime.Date value) {
		return "'" ~ value.toISOExtString() ~ "'";
	}

	/**
	 * Quotes a date-time as `'YYYY-MM-DDTHH:MM:SS'` (ISO extended format).
	 */
	protected override string escapeDateTime(std.datetime.DateTime value) {
		return "'" ~ value.toISOExtString() ~ "'";
	}

	/**
	 * Quotes a time of day as `'HH:MM:SS'` (ISO extended format).
	 */
	protected override string escapeTime(std.datetime.TimeOfDay value) {
		return "'" ~ value.toISOExtString() ~ "'";
	}

	/**
	 * Utilities for prepared statements.
	 */
	public static struct Prepared {

		/// A bound parameter that reports its database type.
		static interface Param {

			public @property Type type();

		}

		/// Parameter holding a value of type `T` tagged with the database type `_type`.
		static class ParamImpl(T, Type _type) : Param {

			public T value;

			public override Type type() {
				return _type;
			}

			public this(T value) {
				this.value = value;
			}

			override string toString() {
				import std.conv : to;
				return value.to!string;
			}

			alias value this;

		}

		/**
		 * Wraps every argument in a `ParamImpl` matching its type.
		 * Unsupported types are rejected at compile time.
		 * NOTE(review): unsigned integers are stored in the signed type of the
		 * same width (e.g. `ubyte` as `byte`), so values above the signed
		 * maximum wrap to negative numbers; confirm the drivers expect this.
		 */
		static Param[] prepare(E...)(E params) {
			Param[] ret;
			foreach(param ; params) {
				alias T = typeof(param);
				static if(is(T == Bool) || is(T == bool)) ret ~= new ParamImpl!(bool, Type.BOOL)(param);
				else static if(is(T == Byte) || is(T == byte) || is(T == ubyte)) ret ~= new ParamImpl!(byte, Type.BYTE)(param);
				else static if(is(T == Short) || is(T == short) || is(T == ushort)) ret ~= new ParamImpl!(short, Type.SHORT)(param);
				else static if(is(T == Integer) || is(T == int) || is(T == uint)) ret ~= new ParamImpl!(int, Type.INT)(param);
				else static if(is(T == Long) || is(T == long) || is(T == ulong)) ret ~= new ParamImpl!(long, Type.LONG)(param);
				// ...
				else static if(is(T == String) || is(T == string)) ret ~= new ParamImpl!(string, Type.STRING)(param);
				else static assert(0, "Type " ~ T.stringof ~ " not supported");
			}
			return ret;
		}

	}

}