1 module shark.impl.postgresql;
2 
import std.algorithm : canFind;
import std.conv : to;
static import std.datetime;
import std.digest : toHexString, LetterCase;
import std.digest.md : md5Of;
import std.exception : enforce;
import std.experimental.logger : trace, info, warning;
import std.socket;
import std.string : join, replace;
import std.system : Endian;
13 
14 import shark.clause;
15 import shark.database : DatabaseException, DatabaseConnectionException, ErrorCodeDatabaseException, ErrorCodesDatabaseException;
16 import shark.sql : SqlDatabase;
17 import shark.util : Stream, read0String, write0String, fromHexString;
18 
19 import xbuffer : Buffer;
20 
21 // debug
22 import std.stdio;
23 
private {

	/// Name of the server-side prepared statement that lists a table's columns.
	enum string infoStatement = "_shark_table_info";

	/// Message stream used to talk to the server.
	/// NOTE(review): the arguments presumably mean "1-byte message id, big-endian,
	/// 4-byte length field that counts itself" — confirm against shark.util.Stream.
	alias PostgresqlStream = Stream!(1, Endian.bigEndian, 4, true);

}
27 
28 /**
29  * PostgreSQL database implementation.
30  */
/**
 * PostgreSQL database implementation.
 *
 * Talks the PostgreSQL frontend/backend protocol (version 3.0) directly over a
 * TCP socket. Ordinary statements use the simple query protocol ('Q'); the
 * table description used by the schema helpers goes through a prepared
 * statement and the extended query protocol (Parse/Bind/Execute/Sync).
 */
class PostgresqlDatabase : SqlDatabase {

	/// Message stream wrapping the TCP connection to the server.
	private PostgresqlStream _stream;

	/// Run-time parameters reported by the server through ParameterStatus ('S') messages.
	private string[string] _status;

	/// Backend key data received in the 'K' message after authentication.
	private uint _serverProcessId;
	private uint _serverSecretKey;

	/**
	 * Set when an ErrorResponse has been turned into an exception. The next
	 * call to `receive` then discards every message up to and including the
	 * following ReadyForQuery ('Z') so the connection is back in sync.
	 */
	private bool _error = false;

	/**
	 * Opens a blocking TCP connection to the server. Authentication happens
	 * later, in `connectImpl`.
	 * Params:
	 *   host = host name or address of the server
	 *   port = TCP port of the server
	 */
	public this(string host, ushort port=5432) {
		Socket socket = new TcpSocket();
		// do not leak the socket when the lookup or the connection fails
		scope(failure) socket.close();
		socket.blocking = true;
		socket.connect(getAddress(host, port)[0]);
		_stream = new PostgresqlStream(socket, 1024);
	}

	/**
	 * Sends the startup message, authenticates, reads the server's parameter
	 * status and backend key data until the server is ready for queries, and
	 * finally prepares the statement used by `getTableInfo`.
	 *
	 * Supported authentication methods: none (0), clear-text password (3) and
	 * MD5 (5).
	 * Throws: DatabaseConnectionException on an unsupported authentication
	 *         method, a failed authentication or an unexpected message.
	 */
	protected override void connectImpl(string db, string user, string password) {
		// StartupMessage: protocol version 3.0 followed by name/value pairs and a final zero byte
		Buffer buffer = new Buffer(64);
		buffer.write!(Endian.bigEndian, uint)(0x0003_0000);
		buffer.write0String("user");
		buffer.write0String(user);
		buffer.write0String("database");
		buffer.write0String(db);
		buffer.write(ubyte(0));
		// the startup message has no identifier byte, so it is sent on the raw socket;
		// its total length (payload + the 4-byte length field) is written at offset 0
		buffer.write!(Endian.bigEndian, uint)(buffer.data.length.to!uint + 4, 0);
		_stream.socket.send(buffer.data);
		buffer = receive();
		enforcePacketSequence('R');
		immutable method = buffer.read!(Endian.bigEndian, uint)();
		bool passwordRequired = true;
		string hashedPassword;
		switch(method) {
			case 0:
				// no password required
				passwordRequired = false;
				break;
			case 3:
				// plain text password
				hashedPassword = password;
				break;
			case 5:
				// hashed password: "md5" ~ hex(md5(hex(md5(password ~ user)) ~ salt))
				void[] salt = buffer.readData(4);
				hashedPassword = "md5" ~ toHexString!(LetterCase.lower)(md5Of(toHexString!(LetterCase.lower)(md5Of(password, user)), salt)).idup;
				break;
			default:
				throw new DatabaseConnectionException("Unknown authentication method requested by the server (" ~ method.to!string ~ ")");
		}
		if(passwordRequired) {
			buffer.reset();
			_stream.id = "p";
			buffer.write0String(hashedPassword);
			_stream.send(buffer);
			buffer = receive();
			enforcePacketSequence('R');
			enforce!DatabaseConnectionException(buffer.read!(Endian.bigEndian, uint)() == 0, "Authentication failed");
		}
		bool loop = true;
		do {
			buffer = receive();
			switch(_stream.id!char[0]) {
				case 'Z':
					// ready for query
					loop = false;
					break;
				case 'S':
					// parameter status
					_status[buffer.read0String().idup] = buffer.read0String().idup;
					break;
				case 'K':
					// backend key data
					_serverProcessId = buffer.read!(Endian.bigEndian, uint)();
					_serverSecretKey = buffer.read!(Endian.bigEndian, uint)();
					break;
				default:
					throw new DatabaseConnectionException("Wrong packet sequence");
			}
		} while(loop);
		// prepare a statement for table description
		prepareQuery(infoStatement, "select column_name, data_type, is_nullable, character_maximum_length, column_default from INFORMATION_SCHEMA.COLUMNS where table_name=$1;", Param.VARCHAR);
	}

	/**
	 * Sends a Terminate ('X') message and closes the socket. Sending is best
	 * effort: a failure is logged and the socket is closed anyway.
	 */
	protected override void closeImpl() {
		scope(exit) _stream.socket.close();
		try {
			_stream.id = "X";
			_stream.send(new Buffer(5));
		} catch(Exception e) {
			warning("Could not send the terminate message: ", e.msg);
		}
	}

	/**
	 * Receives the next message from the server.
	 *
	 * If a previous call threw because of an ErrorResponse, every message up
	 * to and including the next ReadyForQuery ('Z') is discarded first.
	 * NoticeResponse ('N') messages are logged and skipped.
	 *
	 * Returns: the body of the received message; its identifier is available
	 *          through `_stream.id`.
	 * Throws: PostgresqlDatabaseExceptions when the server sends an
	 *         ErrorResponse ('E'); DatabaseConnectionException when a notice
	 *         lacks its severity or message field.
	 */
	private Buffer receive() {
		if(_error) {
			// clear packets received after an exception was thrown and not handled
			_error = false;
			string[] ids;
			do {
				receive();
				ids ~= _stream.id!char;
			} while(_stream.id!char[0] != 'Z');
			warning("An exception was thrown and ", ids.length, " packet(s) (", ids.join(", "), ") has been skipped");
		}
		Buffer buffer = _stream.receive();
		switch(_stream.id!char[0]) {
			case 'E':
				_error = true;
				PostgresqlDatabaseException[] exceptions;
				char errorCode;
				while((errorCode = buffer.read!char()) != '\0') {
					// copy the text: the exceptions outlive the receive buffer
					exceptions ~= new PostgresqlDatabaseException(errorCode, buffer.read0String().idup);
				}
				throw new PostgresqlDatabaseExceptions(exceptions);
			case 'N':
				// a notice is a list of (field code, value) pairs ended by a zero byte;
				// fields are looked up by code ('S' severity, 'M' message), not position
				string[char] notice;
				char noticeCode;
				while((noticeCode = buffer.read!char()) != '\0') {
					notice[noticeCode] = buffer.read0String().idup;
				}
				enforce!DatabaseConnectionException(('S' in notice) !is null && ('M' in notice) !is null, "Received malformed notice with " ~ notice.length.to!string ~ " fields");
				info("PostgreSQL (", notice['S'], "): ", notice['M']);
				return receive();
			default:
				return buffer;
		}
	}

	/// Sends an empty Flush ('H') message.
	private void sendFlush() {
		_stream.id = "H";
		Buffer buffer = new Buffer(5);
		_stream.send(buffer);
	}

	/// Sends an empty Sync ('S') message, which ends an extended-query cycle
	/// and makes the server answer with ReadyForQuery (also after an error).
	private void sendSync() {
		_stream.id = "S";
		Buffer buffer = new Buffer(5);
		_stream.send(buffer);
	}

	// QUERYING

	/**
	 * Sends `query` with the simple query protocol ('Q'). The responses are
	 * left for the caller to read.
	 */
	public override void query(string query) {
		trace("Running query `" ~ query ~ "`");
		_stream.id = "Q";
		Buffer buffer = new Buffer(query.length + 6);
		buffer.write0String(query);
		_stream.send(buffer);
	}

	/**
	 * Creates the named prepared statement `statement` from `query` (Parse
	 * message followed by Sync) and waits for ParseComplete and ReadyForQuery.
	 * Params:
	 *   statement = name of the prepared statement
	 *   query = SQL text, using $1, $2, ... as placeholders
	 *   params = type OIDs of the placeholders
	 * Throws: PostgresqlDatabaseExceptions when the server rejects the query.
	 */
	public void prepareQuery(string statement, string query, Param[] params...) {
		trace("Preparing statement `" ~ statement ~ "` using `" ~ query ~ "`");
		_stream.id = "P";
		Buffer buffer = new Buffer(statement.length + query.length + 9 + params.length * 4);
		buffer.write0String(statement);
		buffer.write0String(query);
		buffer.write!(Endian.bigEndian, ushort)(params.length.to!ushort);
		foreach(param ; params) buffer.write!(Endian.bigEndian, uint)(param);
		_stream.send(buffer);
		// Sync (not Flush): after a Parse error the server discards input until
		// a Sync, and only then sends the ReadyForQuery that `receive` waits for
		sendSync();
		receiveAndEnforcePacketSequence('1');
		enforceReadyForQuery();
	}

	/**
	 * Executes the prepared statement `statement` with `params` through an
	 * unnamed portal (Bind, Execute, Sync) and consumes BindComplete.
	 *
	 * Parameters are sent in text format, except BINARY/BLOB values which are
	 * sent as raw bytes in binary format. Result columns are requested in
	 * binary format. The caller reads the DataRow ('D'), CommandComplete ('C')
	 * and ReadyForQuery ('Z') messages.
	 */
	public void executeQuery(string statement, Prepared.Param[] params...) {
		trace("Executing prepared statement `" ~ statement ~ "` with parameters " ~ params.to!string);
		Buffer buffer = new Buffer(512);
		_stream.id = "B";
		immutable length = params.length.to!ushort;
		buffer.write0String(""); // unnamed portal
		buffer.write0String(statement);
		// parameter format codes: 1 (binary) only for raw bytes, 0 (text) for
		// everything else, because the other values are written as strings below
		buffer.write!(Endian.bigEndian, ushort)(length);
		foreach(param ; params) {
			immutable binary = param !is null && (param.type == Type.BINARY || param.type == Type.BLOB);
			buffer.write!(Endian.bigEndian, ushort)(binary);
		}
		buffer.write!(Endian.bigEndian, ushort)(length);
		void writeImpl(T)(T value) {
			auto str = value.to!string;
			buffer.write!(Endian.bigEndian, uint)(str.length.to!uint);
			buffer.write(str);
		}
		foreach(param ; params) {
			if(param is null) {
				// a length of -1 means NULL
				buffer.write!(Endian.bigEndian, uint)(uint.max);
			} else {
				final switch(param.type) with(Type) {
					case BOOL:
						// first character of the string form ('t' or 'f' — assumes param.to!string yields "true"/"false")
						writeImpl(param.to!string[0]);
						break;
					case BYTE:
					case SHORT:
					case INT:
					case LONG:
					case FLOAT:
					case DOUBLE:
					case CHAR:
					case STRING:
					case CLOB:
					case DATE:
					case DATETIME:
					case TIME:
						writeImpl(param.to!string);
						break;
					case BINARY:
					case BLOB:
						// NOTE(review): a BLOB parameter may not be a ParamImpl!(ubyte[], Type.BINARY); confirm the cast
						writeImpl(cast(string)(cast(Prepared.ParamImpl!(ubyte[], Type.BINARY))param).value);
						break;
				}
			}
		}
		// one result format code, applied to every column: binary
		buffer.write!(Endian.bigEndian, ushort)(1);
		buffer.write!(Endian.bigEndian, ushort)(1);
		_stream.send(buffer);
		// Execute the unnamed portal with no row limit (0)
		buffer.reset();
		_stream.id = "E";
		buffer.write0String("");
		buffer.write!(Endian.bigEndian, uint)(0);
		_stream.send(buffer);
		sendSync();
		receiveAndEnforcePacketSequence('2');
	}

	/**
	 * Runs `query` with the simple query protocol and collects the returned
	 * rows (text format). Statements that return no rows yield an empty result.
	 */
	public override Result querySelect(string query) {
		Result result;
		this.query(query);
		Buffer buffer = receive();
		if(_stream.id!char[0] != 'C') {
			enforcePacketSequence('T');
			ColumnInfo[] columns = readRowDescription(buffer, result);
			while(true) {
				buffer = receive();
				if(_stream.id!char[0] == 'C') break;
				enforcePacketSequence('D');
				result.rows ~= readDataRow(buffer, columns);
			}
		}
		enforceReadyForQuery();
		return result;
	}

	/**
	 * Parses the body of a RowDescription ('T') message and records the
	 * index of every column in `result.columns`.
	 * Returns: name and type OID of each column, in order.
	 */
	private ColumnInfo[] readRowDescription(Buffer buffer, ref Result result) {
		ColumnInfo[] columns;
		foreach(i ; 0..buffer.read!(Endian.bigEndian, ushort)()) {
			ColumnInfo column;
			column.column = buffer.read0String().idup;
			buffer.readData(6); // table OID (4) + column attribute number (2)
			column.type = buffer.read!(Endian.bigEndian, uint)();
			buffer.readData(8); // type size (2) + type modifier (4) + format code (2)
			columns ~= column;
			result.columns[column.column] = i;
		}
		return columns;
	}

	/**
	 * Parses the body of a DataRow ('D') message whose values are in text format.
	 * Throws: DatabaseConnectionException when the number of values does not
	 *         match `columns`.
	 */
	private Result.Row[] readDataRow(Buffer buffer, ColumnInfo[] columns) {
		enforce!DatabaseConnectionException(buffer.read!(Endian.bigEndian, ushort)() == columns.length, "Length of the row doesn't match the column's");
		Result.Row[] row;
		foreach(column ; columns) {
			row ~= parseRow(column.type, buffer);
		}
		return row;
	}

	/**
	 * Reads one text-format value of the type identified by OID `param`.
	 * Throws: DatabaseConnectionException for an unsupported type OID.
	 */
	private Result.Row parseRow(uint param, Buffer buffer) {
		switch(param) with(Param) {
			case BOOL: return readString!bool(buffer);
			case BYTEA: return readString!(ubyte[])(buffer);
			case INT8: return readString!long(buffer);
			case INT2: return readString!short(buffer);
			case INT4: return readString!int(buffer);
			case TEXT: return readString!string(buffer);
			case FLOAT4: return readString!float(buffer);
			case FLOAT8: return readString!double(buffer);
			case CHAR: return readString!char(buffer);
			case VARCHAR: return readString!string(buffer);
			case DATE: return readString!(std.datetime.Date)(buffer);
			case TIMESTAMP: return readString!(std.datetime.DateTime)(buffer);
			case TIME: return readString!(std.datetime.TimeOfDay)(buffer);
			default: throw new DatabaseConnectionException("Unknown type with id " ~ param.to!string);
		}
	}

	/// Name and type OID of a result column.
	private static struct ColumnInfo {

		string column;
		uint type;

	}

	// CREATE | ALTER

	/**
	 * Describes the columns of `table` using the INFORMATION_SCHEMA prepared
	 * statement. Its results come back in binary format (see `executeQuery`).
	 * Returns: the columns' information keyed by column name.
	 */
	protected override TableInfo[string] getTableInfo(string table) {
		executeQuery(infoStatement, Prepared.prepare(table));
		TableInfo[string] ret;
		while(true) {
			Buffer buffer = receive();
			if(_stream.id!char[0] == 'C') break;
			enforcePacketSequence('D');
			enforce!DatabaseConnectionException(buffer.read!(Endian.bigEndian, ushort)() == 5, "Wrong number of fields returned by the server");
			TableInfo field;
			field.name = buffer.read!string(buffer.read!(Endian.bigEndian, uint)()).idup;
			field.type = fromStringToType(buffer.read!string(buffer.read!(Endian.bigEndian, uint)()));
			field.nullable = buffer.read!string(buffer.read!(Endian.bigEndian, uint)()) == "YES";
			// character_maximum_length: NULL (-1) or a 4-byte binary integer
			immutable length = buffer.read!(Endian.bigEndian, uint)();
			if(length != uint.max) field.length = buffer.read!(Endian.bigEndian, uint)();
			immutable defaultValue = buffer.read!(Endian.bigEndian, uint)();
			if(defaultValue != uint.max) field.defaultValue = buffer.read!string(defaultValue).idup;
			ret[field.name] = field;
		}
		enforceReadyForQuery();
		return ret;
	}

	/**
	 * Maps an INFORMATION_SCHEMA `data_type` name to a shark type.
	 * Throws: DatabaseException for an unsupported type name.
	 */
	private uint fromStringToType(string str) {
		switch(str) with(Type) {
			case "boolean": return BOOL;
			case "smallint": return SHORT;
			case "integer": return INT;
			case "bigint": return LONG;
			case "real": return FLOAT;
			case "character": return CHAR;
			case "double precision": return DOUBLE;
			case "character varying": return STRING;
			// NOTE(review): bitwise or of two Type values — presumably compared as flags by the caller; confirm
			case "bytea": return BINARY | BLOB;
			case "text": return CLOB;
			case "date": return DATE;
			case "timestamp": case "timestamp without time zone": return DATETIME;
			case "time": case "time without time zone": return TIME;
			default: throw new DatabaseException("Unknown type '" ~ str ~ "'");
		}
	}

	/// Generates the column definition used in a `create table` statement.
	protected override string generateField(InitInfo.Field field) {
		string[] ret = [field.name];
		ret ~= fromTypeToString(cast(Type)field.type, field.autoIncrement, field.length);
		if(field.length) ret[1] ~= "(" ~ field.length.to!string ~ ")";
		if(!field.nullable) ret ~= "not null";
		if(field.unique) ret ~= "unique";
		return ret.join(" ");
	}

	/**
	 * Maps a shark type to a PostgreSQL type name. `length` may be adjusted:
	 * forced to 1 for CHAR and to 0 for binary types.
	 * Throws: DatabaseException for BYTE, which is not supported.
	 */
	private string fromTypeToString(Type type, bool autoIncrement, ref size_t length) {
		final switch(type) with(Type) {
			case BOOL: return "boolean";
			case BYTE: throw new DatabaseException("Type byte is not supported");
			case SHORT: return autoIncrement ? "serial2" : "int2";
			case INT: return autoIncrement ? "serial4" : "int4";
			case LONG: return autoIncrement ? "serial8" : "int8";
			case FLOAT: return "float4";
			case DOUBLE: return "float8";
			case CHAR:
				length = 1;
				return "char";
			case STRING: return "varchar";
			case BINARY:
			case BLOB:
				length = 0; // bytea(x) not supported
				return "bytea";
			case CLOB: return "text";
			case DATE: return "date";
			case DATETIME: return "timestamp";
			case TIME: return "time";
		}
	}

	/// Creates the table and consumes CommandComplete and ReadyForQuery.
	protected override void createTable(string table, string[] fields) {
		super.createTable(table, fields);
		receiveAndEnforcePacketSequence('C');
		enforceReadyForQuery();
	}

	/// Changes the type and/or the nullability of an existing column.
	protected override void alterTableColumn(string table, InitInfo.Field field, bool typeChanged, bool nullableChanged) {
		string q = "alter table " ~ table ~ " alter column " ~ field.name;
		if(typeChanged) {
			q ~= " type " ~ fromTypeToString(cast(Type)field.type, false, field.length);
			if(field.length) q ~= "(" ~ field.length.to!string ~ ")";
		}
		if(nullableChanged) {
			if(field.nullable) q ~= " drop not null";
			else q ~= " set not null";
		}
		query(q ~ ";");
		receiveAndEnforcePacketSequence('C');
		enforceReadyForQuery();
	}

	/// Adds a column and consumes CommandComplete and ReadyForQuery.
	protected override void alterTableAddColumn(string table, InitInfo.Field field) {
		super.alterTableAddColumn(table, field);
		receiveAndEnforcePacketSequence('C');
		enforceReadyForQuery();
	}

	/// Drops a column and consumes CommandComplete and ReadyForQuery.
	protected override void alterTableDropColumn(string table, string column) {
		super.alterTableDropColumn(table, column);
		receiveAndEnforcePacketSequence('C');
		enforceReadyForQuery();
	}

	// INSERT

	/// Inserts a row and consumes CommandComplete and ReadyForQuery.
	protected override Result insertImpl(InsertInfo insertInfo) {
		auto ret = super.insertImpl(insertInfo);
		receiveAndEnforcePacketSequence('C');
		enforceReadyForQuery();
		return ret;
	}

	/**
	 * Sends an `insert` statement. When `primaryKeys` is not empty the
	 * statement uses `returning` and the returned row is read and returned;
	 * otherwise an empty result is returned without reading anything.
	 * CommandComplete and ReadyForQuery are left to the caller.
	 */
	protected override Result insertInto(string table, string[] names, string[] fields, string[] primaryKeys) {
		string q = "insert into " ~ table ~ " (" ~ names.join(",") ~ ") values (" ~ fields.join(",") ~ ")";
		if(primaryKeys.length) q ~= " returning " ~ primaryKeys.join(",");
		query(q);
		if(!primaryKeys.length) return Result.init;
		Result result;
		Buffer buffer = receive();
		enforcePacketSequence('T');
		ColumnInfo[] columns = readRowDescription(buffer, result);
		enforce!DatabaseConnectionException(columns.length == primaryKeys.length, "Wrong number of fields returned by the server");
		buffer = receive();
		enforcePacketSequence('D');
		result.rows ~= readDataRow(buffer, columns);
		return result;
	}

	// UPDATE

	/// Updates rows and consumes CommandComplete and ReadyForQuery.
	protected override void updateImpl(UpdateInfo updateInfo, Clause.Where where) {
		super.updateImpl(updateInfo, where);
		receiveAndEnforcePacketSequence('C');
		enforceReadyForQuery();
	}

	// DELETE

	/// Deletes rows and consumes CommandComplete and ReadyForQuery.
	protected override void deleteImpl(string table, Clause.Where where) {
		super.deleteImpl(table, where);
		receiveAndEnforcePacketSequence('C');
		enforceReadyForQuery();
	}

	// DROP

	/// Drops the table if it exists and consumes CommandComplete and ReadyForQuery.
	public override void dropIfExists(string table) {
		super.dropIfExists(table);
		receiveAndEnforcePacketSequence('C');
		enforceReadyForQuery();
	}

	/// Drops the table and consumes CommandComplete and ReadyForQuery.
	public override void drop(string table) {
		super.drop(table);
		receiveAndEnforcePacketSequence('C');
		enforceReadyForQuery();
	}

	// UTILS

	/// Throws WrongPacketSequenceException unless the last message received has id `expected`.
	private void enforcePacketSequence(char expected) {
		immutable current = _stream.id!char[0];
		if(current != expected) throw new WrongPacketSequenceException(expected, current);
	}

	/// Receives a message and checks that its id is `expected`.
	private void receiveAndEnforcePacketSequence(char expected) {
		receive();
		enforcePacketSequence(expected);
	}

	/// Receives ReadyForQuery ('Z') and checks its 1-byte status ('I', 'T' or 'E').
	private void enforceReadyForQuery() {
		Buffer buffer = receive();
		enforcePacketSequence('Z');
		enforce!DatabaseConnectionException(buffer.data.length == 1 && "ITE".canFind(buffer.read!char()), "Server is not ready for query");
	}

	protected override string randomFunction() {
		return "random()";
	}

	/// Encodes `value` as a bytea literal in hex format ('\x...').
	protected override string escapeBinary(ubyte[] value) {
		return "'\\x" ~ toHexString(value) ~ "'";
	}

	/// Type OIDs of the PostgreSQL built-in types handled by this driver.
	private enum Param : uint {

		BOOL = 16,
		BYTEA = 17,
		INT8 = 20,
		INT2 = 21,
		INT4 = 23,
		TEXT = 25,
		FLOAT4 = 700,
		FLOAT8 = 701,
		CHAR = 1042,
		VARCHAR = 1043,
		DATE = 1082,
		TIME = 1083,
		TIMESTAMP = 1114,

	}

	/**
	 * Reads a length-prefixed text-format value and converts it to `T`.
	 * A length of -1 (uint.max) is SQL NULL and yields `null`.
	 * Throws: DatabaseConnectionException on a malformed hex bytea value.
	 */
	private Result.Row readString(T)(Buffer buffer) {
		immutable length = buffer.read!(Endian.bigEndian, uint)();
		if(length == uint.max) return null;
		else {
			auto str = buffer.read!string(length).idup;
			static if(is(T == ubyte[])) {
				// hex format: "\x" followed by two digits per byte; "\x" alone is an empty value
				enforce!DatabaseConnectionException(str.length >= 2 && str.length % 2 == 0, "Malformed bytea value received from the server");
				return Result.Row.from(fromHexString(str[2..$]));
			} else static if(is(T == std.datetime.Date)) {
				return Result.Row.from(std.datetime.Date.fromISOExtString(str));
			} else static if(is(T == std.datetime.DateTime)) {
				return Result.Row.from(std.datetime.DateTime.fromISOExtString(str.replace(" ", "T")));
			} else static if(is(T == std.datetime.TimeOfDay)) {
				return Result.Row.from(std.datetime.TimeOfDay.fromISOExtString(str));
			} else static if(is(T == bool)) {
				return Result.Row.from(str == "t");
			} else {
				return Result.Row.from(to!T(str));
			}
		}
	}

}
550 
/**
 * Thrown when the server sends a message whose identifier differs from the
 * one the protocol state requires.
 */
class WrongPacketSequenceException : DatabaseConnectionException {

	/**
	 * Params:
	 *   expected = identifier of the message that was required
	 *   got = identifier of the message actually received
	 */
	public this(char expected, char got, string file=__FILE__, size_t line=__LINE__) {
		immutable message = "Wrong packet sequence (expected '" ~ expected ~ "' but got '" ~ got ~ "')";
		super(message, file, line);
	}

}
558 
/// A single field of a PostgreSQL ErrorResponse, identified by its one-character field code.
alias PostgresqlDatabaseException = ErrorCodeDatabaseException!("PostgreSQL", char);

/// All the fields of one PostgreSQL ErrorResponse, thrown together.
alias PostgresqlDatabaseExceptions = ErrorCodesDatabaseException!(PostgresqlDatabaseException);