module shark.impl.postgresql;

import std.algorithm : canFind;
import std.conv : to;
static import std.datetime;
import std.digest : toHexString, LetterCase;
import std.digest.md : md5Of;
import std.exception : enforce;
import std.experimental.logger : trace, info, warning;
import std.socket;
import std.string : join, replace;
import std.system : Endian;

import shark.clause;
import shark.database : DatabaseException, DatabaseConnectionException, ErrorCodeDatabaseException, ErrorCodesDatabaseException;
import shark.sql : SqlDatabase;
import shark.util : Stream, read0String, write0String, fromHexString;

import xbuffer : Buffer;

// debug
import std.stdio;

/// Name of the server-side prepared statement used to describe a table.
private enum infoStatement = "_shark_table_info";

private alias PostgresqlStream = Stream!(1, Endian.bigEndian, 4, true);

/**
 * PostgreSQL database implementation.
 *
 * Speaks version 3.0 of the PostgreSQL frontend/backend protocol over a
 * blocking TCP socket. Single-character packet identifiers used throughout
 * this class ('R', 'Z', 'T', 'D', 'C', ...) are the message type bytes
 * defined by that protocol.
 */
class PostgresqlDatabase : SqlDatabase {

	private PostgresqlStream _stream;

	/// Run-time parameters reported by the server during startup ('S' packets).
	private string[string] _status;

	/// Values received in the backend key data ('K') packet.
	private uint _serverProcessId;
	private uint _serverSecretKey;

	/// Set when an error packet was received; see `receive`.
	private bool _error = false;

	/**
	 * Opens a blocking TCP connection to the server.
	 * Params:
	 *   host = host name or address of the server
	 *   port = TCP port of the server
	 */
	public this(string host, ushort port=5432) {
		Socket socket = new TcpSocket();
		// do not leak the descriptor when the address lookup or the connection fails
		scope(failure) socket.close();
		socket.blocking = true;
		socket.connect(getAddress(host, port)[0]);
		_stream = new PostgresqlStream(socket, 1024);
	}

	/**
	 * Performs the startup and authentication handshake, waits until the
	 * server is ready for query and prepares the statement used by
	 * `getTableInfo`.
	 * Throws:
	 *   DatabaseConnectionException when the server requests an
	 *   authentication method other than none (0), clear text (3) or md5 (5),
	 *   when the authentication fails or when an unexpected packet arrives.
	 */
	protected override void connectImpl(string db, string user, string password) {
		// startup message: it has no type byte, so it is written to the socket directly
		Buffer buffer = new Buffer(64);
		buffer.write!(Endian.bigEndian, uint)(0x0003_0000);
		buffer.write0String("user");
		buffer.write0String(user);
		buffer.write0String("database");
		buffer.write0String(db);
		buffer.write(ubyte(0));
		// the length prefix counts itself, hence the + 4
		buffer.write!(Endian.bigEndian, uint)(buffer.data.length.to!uint + 4, 0);
		sendRaw(buffer.data);
		buffer = receive();
		enforcePacketSequence('R');
		immutable method = buffer.read!(Endian.bigEndian, uint)();
		bool passwordRequired = true;
		string hashedPassword;
		switch(method) {
			case 0:
				// no password required
				passwordRequired = false;
				break;
			case 3:
				// plain text password
				hashedPassword = password;
				break;
			case 5:
				// hashed password (default): "md5" ~ md5(md5(password ~ user) ~ salt)
				void[] salt = buffer.readData(4);
				hashedPassword = "md5" ~ toHexString!(LetterCase.lower)(md5Of(toHexString!(LetterCase.lower)(md5Of(password, user)), salt)).idup;
				break;
			default:
				throw new DatabaseConnectionException("Unknown authentication method requested by the server (" ~ method.to!string ~ ")");
		}
		if(passwordRequired) {
			buffer.reset();
			_stream.id = "p";
			buffer.write0String(hashedPassword);
			_stream.send(buffer);
			buffer = receive();
			enforcePacketSequence('R');
			enforce!DatabaseConnectionException(buffer.read!(Endian.bigEndian, uint)() == 0, "Authentication failed");
		}
		bool loop = true;
		do {
			buffer = receive();
			switch(_stream.id!char[0]) {
				case 'Z':
					// ready for query
					loop = false;
					break;
				case 'S':
					// parameter status
					_status[buffer.read0String().idup] = buffer.read0String().idup;
					break;
				case 'K':
					// backend key data
					_serverProcessId = buffer.read!(Endian.bigEndian, uint)();
					_serverSecretKey = buffer.read!(Endian.bigEndian, uint)();
					break;
				default:
					throw new DatabaseConnectionException("Wrong packet sequence");
			}
		} while(loop);
		// prepare a statement for table description
		prepareQuery(infoStatement, "select column_name, data_type, is_nullable, character_maximum_length, column_default from INFORMATION_SCHEMA.COLUMNS where table_name=$1;", Param.VARCHAR);
	}

	/// Closes the underlying socket.
	protected override void closeImpl() {
		_stream.socket.close();
	}

	/**
	 * Writes the whole of `data` to the socket, retrying after partial sends.
	 * Throws:
	 *   DatabaseConnectionException when the socket reports an error or
	 *   accepts no data.
	 */
	private void sendRaw(const(void)[] data) {
		while(data.length) {
			immutable sent = _stream.socket.send(data);
			// Socket.ERROR is negative, so a single check covers both failure modes
			enforce!DatabaseConnectionException(sent > 0, "Could not send data to the server");
			data = data[sent..$];
		}
	}

	/**
	 * Receives the next packet that is neither an error nor a notice.
	 *
	 * Notices ('N') are logged and skipped. An error ('E') is converted into
	 * an exception; the next call then discards every packet up to and
	 * including the following 'Z' before reading a new one.
	 * NOTE(review): that drain waits for a 'Z' packet; confirm the server
	 * always sends one after an error raised by `prepareQuery`, which sends
	 * a flush ('H') and not a sync ('S').
	 * Returns: the body of the received packet; its type is available
	 *   through `_stream.id`.
	 * Throws:
	 *   PostgresqlDatabaseExceptions when an error packet is received,
	 *   DatabaseConnectionException when a notice carries no severity or no
	 *   message.
	 */
	private Buffer receive() {
		if(_error) {
			// clear packets received after an exception was thrown and not handled
			_error = false;
			string[] ids;
			do {
				receive();
				ids ~= _stream.id!char;
			} while(_stream.id!char[0] != 'Z');
			warning("An exception was thrown and ", ids.length, " packet(s) (", ids.join(", "), ") has been skipped");
		}
		Buffer buffer = _stream.receive();
		switch(_stream.id!char[0]) {
			case 'E':
				_error = true;
				PostgresqlDatabaseException[] exceptions;
				char errorCode;
				while((errorCode = buffer.read!char()) != '\0') {
					exceptions ~= new PostgresqlDatabaseException(errorCode, buffer.read0String());
				}
				throw new PostgresqlDatabaseExceptions(exceptions);
			case 'N':
				// Fields are identified by their one-byte code, not by position:
				// which fields are present (and so their order) is not fixed.
				string severity, message;
				bool hasSeverity, hasMessage;
				char noticeCode;
				while((noticeCode = buffer.read!char()) != '\0') {
					auto value = buffer.read0String();
					if(noticeCode == 'S') {
						severity = value.idup;
						hasSeverity = true;
					} else if(noticeCode == 'M') {
						message = value.idup;
						hasMessage = true;
					}
				}
				enforce!DatabaseConnectionException(hasSeverity && hasMessage, "Received malformed notice without severity or message");
				info("PostgreSQL (", severity, "): ", message);
				return receive();
			default:
				return buffer;
		}
	}

	/// Sends an empty flush ('H') packet.
	private void sendFlush() {
		_stream.id = "H";
		Buffer buffer = new Buffer(5);
		_stream.send(buffer);
	}

	// QUERYING

	/**
	 * Sends a simple query ('Q') packet. The response is not read here:
	 * callers are expected to consume it.
	 */
	public override void query(string query) {
		trace("Running query `" ~ query ~ "`");
		_stream.id = "Q";
		Buffer buffer = new Buffer(query.length + 6);
		buffer.write0String(query);
		_stream.send(buffer);
	}

	/**
	 * Prepares a named statement on the server and waits for the
	 * confirmation packet ('1').
	 * Params:
	 *   statement = name of the prepared statement
	 *   query = text of the query, with `$n` placeholders
	 *   params = type identifiers of the placeholders
	 */
	public void prepareQuery(string statement, string query, Param[] params...) {
		trace("Preparing statement `" ~ statement ~ "` using `" ~ query ~ "`");
		_stream.id = "P";
		Buffer buffer = new Buffer(statement.length + query.length + 9 + params.length * 4);
		buffer.write0String(statement);
		buffer.write0String(query);
		buffer.write!(Endian.bigEndian, ushort)(params.length.to!ushort);
		foreach(param ; params) buffer.write!(Endian.bigEndian, uint)(param);
		_stream.send(buffer);
		sendFlush();
		receive();
		enforcePacketSequence('1');
	}

	/**
	 * Binds the given parameters to a prepared statement, executes it and
	 * waits for the bind confirmation packet ('2'). The rows produced by the
	 * statement are left for the caller to receive; they are requested in
	 * binary format.
	 * Params:
	 *   statement = name of a statement created with `prepareQuery`
	 *   params = values of the placeholders; `null` is sent as SQL NULL
	 */
	public void executeQuery(string statement, Prepared.Param[] params...) {
		trace("Executing prepared statement `" ~ statement ~ "` with parameters " ~ params.to!string);
		Buffer buffer = new Buffer(512);
		_stream.id = "B";
		immutable length = params.length.to!ushort;
		buffer.write0String(""); // unnamed portal
		buffer.write0String(statement);
		// one format code per parameter: 0 is text, 1 is binary
		buffer.write!(Endian.bigEndian, ushort)(length);
		foreach(param ; params) {
			// Values are serialised with `to!string` below, so they must be declared
			// as text; only binary data is written as raw bytes.
			immutable bool binary = param !is null && (param.type == Type.BINARY || param.type == Type.BLOB);
			buffer.write!(Endian.bigEndian, ushort)(binary);
		}
		buffer.write!(Endian.bigEndian, ushort)(length);
		void writeImpl(T)(T value) {
			auto str = value.to!string;
			buffer.write!(Endian.bigEndian, uint)(str.length.to!uint);
			buffer.write(str);
		}
		foreach(param ; params) {
			if(param is null) {
				// a length of -1 marks a NULL value
				buffer.write!(Endian.bigEndian, uint)(uint.max);
			} else {
				final switch(param.type) with(Type) {
					case BOOL:
						// first character of the textual value
						writeImpl(param.to!string[0]);
						break;
					case BYTE:
					case SHORT:
					case INT:
					case LONG:
					case FLOAT:
					case DOUBLE:
					case CHAR:
					case STRING:
					case CLOB:
					case DATE:
					case DATETIME:
					case TIME:
						writeImpl(param.to!string);
						break;
					case BINARY:
					case BLOB:
						// NOTE(review): the cast targets the BINARY implementation for both cases;
						// confirm that BLOB parameters are instances of the same class.
						writeImpl(cast(string)(cast(Prepared.ParamImpl!(ubyte[], Type.BINARY))param).value);
						break;
				}
			}
		}
		// one result format code, applied to every column: binary
		buffer.write!(Endian.bigEndian, ushort)(1);
		buffer.write!(Endian.bigEndian, ushort)(1);
		_stream.send(buffer);
		// execute the unnamed portal
		buffer.reset();
		_stream.id = "E";
		buffer.write0String("");
		buffer.write(0);
		_stream.send(buffer);
		// sync
		buffer.reset();
		_stream.id = "S";
		_stream.send(buffer);
		receiveAndEnforcePacketSequence('2');
	}

	/**
	 * Runs a simple query and collects the rows it returns.
	 * Returns: the column indexes and the rows; empty when the server
	 *   answers directly with a command completion ('C') packet.
	 * Throws:
	 *   DatabaseConnectionException when the packets do not arrive in the
	 *   expected order or a row has a different size than the description.
	 */
	public override Result querySelect(string query) {
		Result result;
		this.query(query);
		Buffer buffer = receive();
		if(_stream.id!char[0] != 'C') {
			enforcePacketSequence('T');
			ColumnInfo[] columns = readRowDescription(buffer, buffer.read!(Endian.bigEndian, ushort)(), result);
			while(true) {
				buffer = receive();
				if(_stream.id!char[0] == 'C') break;
				enforcePacketSequence('D');
				enforce!DatabaseConnectionException(buffer.read!(Endian.bigEndian, ushort)() == columns.length, "Length of the row doesn't match the column's");
				Result.Row[] rows;
				foreach(column ; columns) {
					rows ~= parseRow(column.type, buffer);
				}
				result.rows ~= rows;
			}
		}
		enforceReadyForQuery();
		return result;
	}

	/**
	 * Reads `count` column descriptions from a row description ('T') packet
	 * whose column counter has already been consumed, and registers the
	 * index of every column in `result.columns`.
	 */
	private ColumnInfo[] readRowDescription(Buffer buffer, size_t count, ref Result result) {
		ColumnInfo[] columns;
		foreach(i ; 0..count) {
			ColumnInfo column;
			column.column = buffer.read0String().idup;
			buffer.readData(6); // table id (4 bytes) and column number (2 bytes)
			column.type = buffer.read!(Endian.bigEndian, uint)();
			buffer.readData(8); // type size (2 bytes), type modifier (4 bytes) and format code (2 bytes)
			columns ~= column;
			result.columns[column.column] = i;
		}
		return columns;
	}

	/**
	 * Reads one text-encoded field of a data row and converts it according
	 * to the type identifier of its column.
	 * Throws: DatabaseConnectionException when the type is not one of `Param`.
	 */
	private Result.Row parseRow(uint param, Buffer buffer) {
		switch(param) with(Param) {
			case BOOL: return readString!bool(buffer);
			case BYTEA: return readString!(ubyte[])(buffer);
			case INT8: return readString!long(buffer);
			case INT2: return readString!short(buffer);
			case INT4: return readString!int(buffer);
			case TEXT: return readString!string(buffer);
			case FLOAT4: return readString!float(buffer);
			case FLOAT8: return readString!double(buffer);
			case CHAR: return readString!char(buffer);
			case VARCHAR: return readString!string(buffer);
			case DATE: return readString!(std.datetime.Date)(buffer);
			case TIMESTAMP: return readString!(std.datetime.DateTime)(buffer);
			case TIME: return readString!(std.datetime.TimeOfDay)(buffer);
			default: throw new DatabaseConnectionException("Unknown type with id " ~ param.to!string);
		}
	}

	/// Name and type identifier of a column of a result set.
	private static struct ColumnInfo {

		string column;
		uint type;

	}

	// CREATE | ALTER

	/**
	 * Describes the columns of a table using the statement prepared in
	 * `connectImpl`.
	 * Returns: the description of every column, indexed by column name.
	 */
	protected override TableInfo[string] getTableInfo(string table) {
		executeQuery(infoStatement, Prepared.prepare(table));
		TableInfo[string] ret;
		while(true) {
			Buffer buffer = receive();
			if(_stream.id!char[0] == 'C') break;
			enforcePacketSequence('D');
			enforce!DatabaseConnectionException(buffer.read!(Endian.bigEndian, ushort)() == 5, "Wrong number of fields returned by the server");
			TableInfo field;
			// every field is prefixed by its length; uint.max (-1) marks a NULL value
			field.name = buffer.read!string(buffer.read!(Endian.bigEndian, uint)()).idup;
			field.type = fromStringToType(buffer.read!string(buffer.read!(Endian.bigEndian, uint)()));
			field.nullable = buffer.read!string(buffer.read!(Endian.bigEndian, uint)()) == "YES";
			immutable length = buffer.read!(Endian.bigEndian, uint)();
			// results are requested in binary format, so the value is a 4-byte integer
			if(length != uint.max) field.length = buffer.read!(Endian.bigEndian, uint)();
			immutable defaultValue = buffer.read!(Endian.bigEndian, uint)();
			if(defaultValue != uint.max) field.defaultValue = buffer.read!string(defaultValue).idup;
			ret[field.name] = field;
		}
		enforceReadyForQuery();
		return ret;
	}

	/**
	 * Converts a `data_type` value of `INFORMATION_SCHEMA.COLUMNS` into a `Type`.
	 * Throws: DatabaseException when the name is not known.
	 */
	private uint fromStringToType(string str) {
		switch(str) with(Type) {
			case "boolean": return BOOL;
			case "smallint": return SHORT;
			case "integer": return INT;
			case "bigint": return LONG;
			case "real": return FLOAT;
			case "character": return CHAR;
			case "double precision": return DOUBLE;
			case "character varying": return STRING;
			// NOTE(review): both values are combined, presumably because the
			// two types map to the same column type; confirm against `Type`.
			case "bytea": return BINARY | BLOB;
			case "text": return CLOB;
			case "date": return DATE;
			case "timestamp": case "timestamp without time zone": return DATETIME;
			case "time": case "time without time zone": return TIME;
			default: throw new DatabaseException("Unknown type '" ~ str ~ "'");
		}
	}

	/// Builds the definition of a column: name, type, length and constraints.
	protected override string generateField(InitInfo.Field field) {
		string[] ret = [field.name];
		// may adjust field.length, which is read right after
		ret ~= fromTypeToString(cast(Type)field.type, field.autoIncrement, field.length);
		if(field.length) ret[1] ~= "(" ~ field.length.to!string ~ ")";
		if(!field.nullable) ret ~= "not null";
		if(field.unique) ret ~= "unique";
		return ret.join(" ");
	}

	/**
	 * Converts a `Type` into the name of the column type.
	 * Params:
	 *   type = type of the column
	 *   autoIncrement = whether to use the serial variant of integer types
	 *   length = length of the column; forced to 1 for `CHAR` and to 0 for
	 *     binary types
	 * Throws: DatabaseException for `BYTE`, which is not supported.
	 */
	private string fromTypeToString(Type type, bool autoIncrement, ref size_t length) {
		final switch(type) with(Type) {
			case BOOL: return "boolean";
			case BYTE: throw new DatabaseException("Type byte is not supported");
			case SHORT: return autoIncrement ? "serial2" : "int2";
			case INT: return autoIncrement ? "serial4" : "int4";
			case LONG: return autoIncrement ? "serial8" : "int8";
			case FLOAT: return "float4";
			case DOUBLE: return "float8";
			case CHAR:
				length = 1;
				return "char";
			case STRING: return "varchar";
			case BINARY:
			case BLOB:
				length = 0; // bytea(x) not supported
				return "bytea";
			case CLOB: return "text";
			case DATE: return "date";
			case DATETIME: return "timestamp";
			case TIME: return "time";
		}
	}

	protected override void createTable(string table, string[] fields) {
		super.createTable(table, fields);
		finishCommand();
	}

	/// Changes the type and/or the nullability of an existing column.
	protected override void alterTableColumn(string table, InitInfo.Field field, bool typeChanged, bool nullableChanged) {
		string q = "alter table " ~ table ~ " alter column " ~ field.name;
		if(typeChanged) {
			q ~= " type " ~ fromTypeToString(cast(Type)field.type, false, field.length);
			if(field.length) q ~= "(" ~ field.length.to!string ~ ")";
		}
		if(nullableChanged) {
			if(field.nullable) q ~= " drop not null";
			else q ~= " set not null";
		}
		query(q ~ ";");
		finishCommand();
	}

	protected override void alterTableAddColumn(string table, InitInfo.Field field) {
		super.alterTableAddColumn(table, field);
		finishCommand();
	}

	protected override void alterTableDropColumn(string table, string column) {
		super.alterTableDropColumn(table, column);
		finishCommand();
	}

	// INSERT

	protected override Result insertImpl(InsertInfo insertInfo) {
		auto ret = super.insertImpl(insertInfo);
		finishCommand();
		return ret;
	}

	/**
	 * Inserts a row and, when primary keys are given, reads their generated
	 * values back through a `returning` clause. The command completion is
	 * left for `insertImpl` to consume.
	 * Returns: a result with one row containing the primary keys, or an
	 *   empty result when `primaryKeys` is empty.
	 * Throws:
	 *   DatabaseConnectionException when the server does not answer with a
	 *   row description followed by a data row of the expected size.
	 */
	protected override Result insertInto(string table, string[] names, string[] fields, string[] primaryKeys) {
		string q = "insert into " ~ table ~ " (" ~ names.join(",") ~ ") values (" ~ fields.join(",") ~ ")";
		if(primaryKeys.length) q ~= " returning " ~ primaryKeys.join(",");
		query(q);
		if(primaryKeys.length) {
			Result result;
			Buffer buffer = receive();
			enforcePacketSequence('T');
			enforce!DatabaseConnectionException(buffer.read!(Endian.bigEndian, ushort)() == primaryKeys.length, "Wrong number of fields returned by the server");
			ColumnInfo[] info = readRowDescription(buffer, primaryKeys.length, result);
			buffer = receive();
			enforcePacketSequence('D');
			enforce!DatabaseConnectionException(buffer.read!(Endian.bigEndian, ushort)() == primaryKeys.length, "Wrong number of fields returned by the server");
			Result.Row[] rows;
			foreach(column ; info) {
				rows ~= parseRow(column.type, buffer);
			}
			result.rows ~= rows;
			return result;
		} else {
			return Result.init;
		}
	}

	// UPDATE

	protected override void updateImpl(UpdateInfo updateInfo, Clause.Where where) {
		super.updateImpl(updateInfo, where);
		finishCommand();
	}

	// DELETE

	protected override void deleteImpl(string table, Clause.Where where) {
		super.deleteImpl(table, where);
		finishCommand();
	}

	// DROP

	public override void dropIfExists(string table) {
		super.dropIfExists(table);
		finishCommand();
	}

	public override void drop(string table) {
		super.drop(table);
		finishCommand();
	}

	// UTILS

	/**
	 * Checks the type of the last received packet.
	 * Throws: WrongPacketSequenceException when it is not `expected`.
	 */
	private void enforcePacketSequence(char expected) {
		immutable current = _stream.id!char[0];
		if(current != expected) throw new WrongPacketSequenceException(expected, current);
	}

	/// Receives a packet, discards its body and checks its type.
	private void receiveAndEnforcePacketSequence(char expected) {
		receive();
		enforcePacketSequence(expected);
	}

	/**
	 * Receives a ready for query ('Z') packet and checks that it carries a
	 * single status byte that is one of 'I', 'T' or 'E'.
	 */
	private void enforceReadyForQuery() {
		Buffer buffer = receive();
		enforcePacketSequence('Z');
		enforce!DatabaseConnectionException(buffer.data.length == 1 && "ITE".canFind(buffer.read!char()), "Server is not ready for query");
	}

	/**
	 * Consumes the response to a command that returns no rows: a command
	 * completion ('C') packet followed by a ready for query ('Z') packet.
	 */
	private void finishCommand() {
		receiveAndEnforcePacketSequence('C');
		enforceReadyForQuery();
	}

	protected override string randomFunction() {
		return "random()";
	}

	/// Formats binary data as a quoted hexadecimal literal prefixed by `\x`.
	protected override string escapeBinary(ubyte[] value) {
		return "'\\x" ~ toHexString(value) ~ "'";
	}

	/// Type identifiers used in statement preparation and row descriptions.
	private enum Param : uint {

		BOOL = 16,
		BYTEA = 17,
		INT8 = 20,
		INT2 = 21,
		INT4 = 23,
		TEXT = 25,
		FLOAT4 = 700,
		FLOAT8 = 701,
		CHAR = 1042,
		VARCHAR = 1043,
		DATE = 1082,
		TIME = 1083,
		TIMESTAMP = 1114,

	}

	/**
	 * Reads a length-prefixed, text-encoded value and converts it to `T`.
	 * Returns: the converted value, or `null` when the length is uint.max
	 *   (SQL NULL).
	 * Throws:
	 *   DatabaseConnectionException when a binary value is not in the
	 *   `\x`-prefixed hexadecimal format.
	 */
	private Result.Row readString(T)(Buffer buffer) {
		immutable length = buffer.read!(Endian.bigEndian, uint)();
		if(length == uint.max) return null;
		else {
			auto str = buffer.read!string(length).idup;
			static if(is(T == ubyte[])) {
				// an empty value is encoded as the bare prefix, which is 2 characters long
				enforce!DatabaseConnectionException(str.length >= 2 && str[0..2] == "\\x" && str.length % 2 == 0, "Malformed binary value returned by the server");
				return Result.Row.from(fromHexString(str[2..$]));
			} else static if(is(T == std.datetime.Date)) {
				return Result.Row.from(std.datetime.Date.fromISOExtString(str));
			} else static if(is(T == std.datetime.DateTime)) {
				// the server separates date and time with a space, ISO with a 'T'
				return Result.Row.from(std.datetime.DateTime.fromISOExtString(str.replace(" ", "T")));
			} else static if(is(T == std.datetime.TimeOfDay)) {
				return Result.Row.from(std.datetime.TimeOfDay.fromISOExtString(str));
			} else static if(is(T == bool)) {
				return Result.Row.from(str == "t");
			} else {
				return Result.Row.from(to!T(str));
			}
		}
	}

}

/// Thrown when the server sends a packet of a different type than expected.
class WrongPacketSequenceException : DatabaseConnectionException {

	public this(char expected, char got, string file=__FILE__, size_t line=__LINE__) {
		super("Wrong packet sequence (expected '" ~ expected ~ "' but got '" ~ got ~ "')", file, line);
	}

}

alias PostgresqlDatabaseException = ErrorCodeDatabaseException!("PostgreSQL", char);

alias PostgresqlDatabaseExceptions = ErrorCodesDatabaseException!PostgresqlDatabaseException;