module shark.database;

import std.conv : to;
static import std.datetime;
import std.exception : enforce;
import std.string : join, split, strip, startsWith;
import std.traits : hasUDA, getUDAs;

import shark.clause;
import shark.entity;
import shark.util : toSnakeCase;

// debug
import std.stdio : writeln;

/**
 * Represents a generic database.
 * Concrete drivers implement the protected `*Impl` and `escape*` methods.
 */
class Database {

	private string _db = null;

	/**
	 * Performs authentication and connects to a database.
	 * It is possible to reconnect using the same object by calling
	 * `close` and then `connect` again.
	 * If `connectImpl` throws, `db` is reset to null so that it does
	 * not report a database that was never opened.
	 */
	public void connect(string db, string user, string password="") {
		_db = db;
		// keep `db` consistent with its documented contract when connecting fails
		scope(failure) _db = null;
		this.connectImpl(db, user, password);
	}

	/// ditto
	public void connect(string password="") {
		this.connect(null, "", password);
	}

	protected abstract void connectImpl(string db, string user, string password);

	/**
	 * Indicates the name of the database opened or null
	 * if the database isn't connected.
	 */
	public @property string db() {
		return _db;
	}

	/**
	 * Closes the connection with the database.
	 * Should only be called after `connect`.
	 */
	public void close() {
		closeImpl();
		_db = null;
	}

	protected abstract void closeImpl();

	// INIT

	/**
	 * Initializes an entity, either by creating it or updating
	 * its fields when it already exists.
	 * Example:
	 * ---
	 * class Test : Entity {
	 *
	 *   override string tableName() {
	 *     return "test";
	 *   }
	 *
	 *   @PrimaryKey
	 *   @AutoIncrement
	 *   Long testId;
	 *
	 *   @NotNull
	 *   Integer a;
	 *
	 *   @Length(10)
	 *   String b;
	 *
	 * }
	 *
	 * database.init!Test();
	 * ---
	 */
	public void init(T:Entity)() {
		enum initInfo = generateInitInfo!T(); // generate at compile time
		initImpl(initInfo);
	}

	/*
	 * Builds the table description of `T` (evaluated at compile time by `init`).
	 * Marker attributes are recognised both as types (`@PrimaryKey`) and as
	 * values (`@PrimaryKey()`), the same forms accepted by `hasUDA`, which the
	 * rest of this module uses to detect primary keys.
	 */
	private static InitInfo generateInitInfo(T:Entity)() {
		InitInfo ret;
		ret.tableName = new T().tableName;
		static foreach(immutable member ; getEntityMembers!T) {
			{
				InitInfo.Field field;
				field.name = memberName!(T, member);
				field.type = memberType!(typeof(__traits(getMember, T, member)));
				field.nullable = memberNullable!(typeof(__traits(getMember, T, member)));
				foreach(uda ; __traits(getAttributes, __traits(getMember, T, member))) {
					static if(is(uda == PrimaryKey) || is(typeof(uda) == PrimaryKey)) {
						ret.primaryKeys ~= field.name;
						field.nullable = false;
					} else static if(is(uda == AutoIncrement) || is(typeof(uda) == AutoIncrement)) {
						field.autoIncrement = true;
						field.nullable = false;
					} else static if(is(uda == NotNull) || is(typeof(uda) == NotNull)) {
						field.nullable = false;
					} else static if(is(uda == Unique) || is(typeof(uda) == Unique)) {
						field.unique = true;
					} else static if(is(typeof(uda) == Length)) {
						field.length = uda.length;
					}
				}
				ret.fields ~= field;
			}
		}
		return ret;
	}

	protected abstract void initImpl(InitInfo);

	protected static struct InitInfo {

		string tableName;
		Field[] fields;
		string[] primaryKeys;

		static struct Field {

			string name;
			uint type;
			size_t length = 0;
			bool nullable = true;
			bool unique = false;
			bool autoIncrement = false;
			string defaultValue;

		}

	}

	// SELECT

	/**
	 * Clauses for select.
	 * It is possible to instantiate the struct with the parameters
	 * in any order; omitted clauses keep their `init` value.
	 * Example:
	 * ---
	 * Select(Clause.OrderBy.random, Clause.Limit(12));
	 * ---
	 */
	public static struct Select {

		Clause.Where where;

		Clause.Order order;

		Clause.Limit limit;

		// calling this because of a bug in the default constructor
		private void set(Clause.Where where, Clause.Order order, Clause.Limit limit) {
			this.where = where;
			this.order = order;
			this.limit = limit;
		}

		// Every parameter is named on purpose: an unnamed (defaulted) parameter
		// would be ignored and `set` would receive the field's current value.

		this(Clause.Where where, Clause.Order order=Clause.Order.init, Clause.Limit limit=Clause.Limit.init) {
			set(where, order, limit);
		}

		this(Clause.Where where, Clause.Limit limit, Clause.Order order=Clause.Order.init) {
			set(where, order, limit);
		}

		this(Clause.Order order, Clause.Where where=Clause.Where.init, Clause.Limit limit=Clause.Limit.init) {
			set(where, order, limit);
		}

		this(Clause.Order order, Clause.Limit limit, Clause.Where where=Clause.Where.init) {
			set(where, order, limit);
		}

		// `where` is defaulted so that `Select(limit)` is also accepted
		this(Clause.Limit limit, Clause.Where where=Clause.Where.init, Clause.Order order=Clause.Order.init) {
			set(where, order, limit);
		}

		this(Clause.Limit limit, Clause.Order order, Clause.Where where=Clause.Where.init) {
			set(where, order, limit);
		}

	}

	/**
	 * Selects entities from the database using the optional given clauses.
	 * The name of the fields should correspond to the name of the entity's
	 * fields in the database, not to the ones in the D program.
	 * Example:
	 * ---
	 * database.select!(["a", "b"], Test)();
	 * database.select!("testId", Test)();
	 * database.select!Test(Database.Select(Clause.Limit(10)));
	 * ---
	 */
	public T[] select(string[] fields, T:Entity)(Select select=Select.init) {
		return selectImpl!(fields, T)(select);
	}

	/// ditto
	public T[] select(string field, T:Entity)(Select select=Select.init) {
		return selectImpl!([field], T)(select);
	}

	/// ditto
	public T[] select(T:Entity)(Select select=Select.init) {
		return selectImpl!([], T)(select);
	}

	private T[] selectImpl(string[] fields, T:Entity)(Select select=Select.init) {
		SelectInfo selectInfo;
		selectInfo.tableName = new T().tableName;
		//static foreach(field ; fields) selectInfo.fields ~= memberName!(T, field);
		selectInfo.fields = fields;
		return selectImpl(selectInfo, select).bind!T();
	}

	/**
	 * Selects one entity from the database.
	 * The limit of the given clauses is always replaced with 1.
	 * Returns: the first entity found, or null when there are none.
	 */
	public T selectOne(string[] fields, T:Entity)(Select select=Select.init) {
		return selectOneImpl!(fields, T)(select);
	}

	/// ditto
	public T selectOne(string field, T:Entity)(Select select=Select.init) {
		return selectOneImpl!([field], T)(select);
	}

	/// ditto
	public T selectOne(T:Entity)(Select select=Select.init) {
		return selectOneImpl!([], T)(select);
	}

	private T selectOneImpl(string[] fields, T:Entity)(Select select=Select.init) {
		select.limit = Clause.Limit(1);
		T[] ret = this.select!(fields, T)(select);
		if(ret.length) return ret[0];
		else return null;
	}

	/**
	 * Selects an entity from its primary key(s).
	 * The where clause of `select` is replaced with one matching the
	 * entity's primary key(s) and the limit with 1; the other clauses
	 * given by the caller are kept.
	 */
	public T selectId(string[] fields, T:Entity)(T entity, Select select=Select.init) if(getEntityPrimaryKeys!T.length) {
		return selectIdImpl!(fields, T)(entity, select);
	}

	/// ditto
	public T selectId(string field, T:Entity)(T entity, Select select=Select.init) if(getEntityPrimaryKeys!T.length) {
		return selectIdImpl!([field], T)(entity, select);
	}

	/// ditto
	public T selectId(T:Entity)(T entity, Select select=Select.init) if(getEntityPrimaryKeys!T.length) {
		return selectIdImpl!([], T)(entity, select);
	}

	private T selectIdImpl(string[] fields, T:Entity)(T entity, Select select=Select.init) if(getEntityPrimaryKeys!T.length) {
		// honour the caller's clauses, only the where clause is forced
		select.where = makeWhereFromPrimaryKeys(entity);
		return selectOne!(fields, T)(select);
	}

	protected abstract Result selectImpl(SelectInfo, Select);

	protected static struct SelectInfo {

		string tableName;
		string[] fields;

	}

	// INSERT

	/**
	 * Inserts a new entity into the database.
	 * Nullable fields that are null are not sent.
	 * Every row returned by the implementation is applied back to `entity`
	 * (see `Result.apply`). When `updateId` is true the names of the primary
	 * keys are passed to the implementation in `InsertInfo.primaryKeys`
	 * (presumably so that generated ids can be returned). To refresh the
	 * other fields after an insert use select.
	 * Example:
	 * ---
	 * Test test = new Test();
	 * test.a = 55;
	 * test.b = "Test";
	 * database.insert(test);
	 * ---
	 */
	public void insert(T:Entity)(T entity, bool updateId=true) {
		Result result = insertImpl(generateInsertInfo(entity, updateId));
		foreach(row ; result.rows) result.apply(entity, row);
	}

	private InsertInfo generateInsertInfo(T:Entity)(T entity, bool updateId) {
		InsertInfo ret;
		ret.tableName = entity.tableName;
		static foreach(immutable member ; getEntityMembers!T) {
			{
				static if(hasUDA!(__traits(getMember, T, member), PrimaryKey)) {
					if(updateId) ret.primaryKeys ~= memberName!(T, member);
				}
				// non-nullable members are always sent, nullable ones only when set
				static if(!memberNullable!(typeof(__traits(getMember, T, member)))) enum condition = "true";
				else enum condition = "!entity." ~ member ~ ".isNull";
				if(mixin(condition)) {
					InsertInfo.Field field;
					field.name = memberName!(T, member);
					field.value = escape(mixin("entity." ~ member));
					ret.fields ~= field;
				}
			}
		}
		return ret;
	}

	protected abstract Result insertImpl(InsertInfo);

	protected static struct InsertInfo {

		string tableName;
		Field[] fields;
		string[] primaryKeys;

		static struct Field {

			string name;
			string value;

		}

	}

	// UPDATE

	/**
	 * Updates one or more table fields.
	 * The given fields should correspond to the ones in the entity class,
	 * not to the ones in the database.
	 */
	public void update(string[] fields, T:Entity)(T entity, Clause.Where where) if(fields.length) {
		UpdateInfo updateInfo;
		updateInfo.tableName = entity.tableName;
		static foreach(field ; fields) {
			updateInfo.fields ~= UpdateInfo.Field(memberName!(T, field), escape(mixin("entity." ~ field)));
		}
		updateImpl(updateInfo, where);
	}

	/// ditto
	public void update(string field, T:Entity)(T entity, Clause.Where where) {
		return update!([field], T)(entity, where);
	}

	/**
	 * Updates a single row of a table searching by the entity's
	 * primary key(s).
	 */
	public void update(string[] fields, T:Entity)(T entity) if(getEntityPrimaryKeys!T.length) {
		update!(fields, T)(entity, makeWhereFromPrimaryKeys(entity));
	}

	/// ditto
	public void update(string field, T:Entity)(T entity) if(getEntityPrimaryKeys!T.length) {
		return update!([field], T)(entity);
	}

	protected abstract void updateImpl(UpdateInfo, Clause.Where);

	protected static struct UpdateInfo {

		string tableName;
		Field[] fields;

		static struct Field {

			string name;
			string value;

		}

	}

	// DELETE

	/**
	 * Deletes the rows of a table that match the given where clause.
	 */
	public void del(string table, Clause.Where where) {
		deleteImpl(table, where);
	}

	/**
	 * Deletes zero or one row using the entity's primary key(s).
	 * Example:
	 * ---
	 * database.del(test);
	 * ---
	 */
	public void del(T:Entity)(T entity) if(getEntityPrimaryKeys!T.length) {
		del(entity.tableName, makeWhereFromPrimaryKeys(entity));
	}

	protected abstract void deleteImpl(string, Clause.Where);

	// DROP

	/**
	 * Deletes a table if exists.
	 */
	public abstract void dropIfExists(string table);

	/**
	 * Deletes a table. May throw an exception if the table
	 * does not exist.
	 */
	public abstract void drop(string table);

	// UTILS

	/*
	 * Builds a where clause that matches every primary key of `entity`
	 * (`key1 = value1 AND key2 = value2 ...`).
	 */
	private static Clause.Where makeWhereFromPrimaryKeys(T)(T entity) if(getEntityPrimaryKeys!T.length) {
		Clause.Where where;
		Clause.Where.GenericStatement[] statements;
		static foreach(immutable member ; getEntityPrimaryKeys!T) {
			statements ~= new Clause.Where.Statement(memberName!(T, member), Clause.Where.Operator.equals, mixin("entity." ~ member));
		}
		// the template constraint guarantees at least one statement
		where.statement = statements[0];
		foreach(statement ; statements[1..$]) {
			where.statement = new Clause.Where.ComplexStatement(where.statement, Clause.Where.Glue.and, statement);
		}
		return where;
	}

	/**
	 * Result of a select query.
	 */
	public static struct Result {

		size_t[string] columns; // position in the array of the column

		Row[][] rows;

		/**
		 * Creates n objects from the result. T doesn't have to
		 * extend `Entity`.
		 * Example:
		 * ---
		 * class Test {
		 *
		 *   String a;
		 *
		 *   Integer b;
		 *
		 * }
		 * ...
		 * Test[] entities = result.bind!Test();
		 * ---
		 */
		public T[] bind(T)() {
			T[] ret;
			foreach(row ; rows) {
				T entity;
				static if(is(T == class)) entity = new T();
				apply(entity, row);
				ret ~= entity;
			}
			return ret;
		}

		/**
		 * Applies the result of one row to entity, passed by reference.
		 * The entity doesn't have to extend `Entity`.
		 * Columns are matched by database name (see `@Name`); members
		 * without a matching column are left untouched.
		 * Throws: `DatabaseException` when a null value is found for a
		 * non-nullable member or when a value has an unexpected type.
		 */
		public void apply(T)(ref T entity, Row[] row) {
			static foreach(immutable member ; getEntityMembers!T) {
				{
					auto ptr = memberName!(T, member) in columns;
					if(ptr) {
						auto v = row[*ptr];
						if(v is null) {
							static if(memberNullable!(typeof(__traits(getMember, T, member)))) mixin("entity." ~ member).nullify();
							else throw new DatabaseException("Could not nullify " ~ T.stringof ~ "." ~ member);
						} else {
							alias R = typeof(__traits(getMember, T, member));
							static if(is(R == Bool) || is(R == bool)) {
								auto value = cast(Result.RowImpl!bool)v;
							} else static if(is(R == Byte) || is(R == byte) || is(R == ubyte)) {
								auto value = cast(Result.RowImpl!byte)v;
							} else static if(is(R == Short) || is(R == short) || is(R == ushort)) {
								auto value = cast(Result.RowImpl!short)v;
							} else static if(is(R == Integer) || is(R == int) || is(R == uint)) {
								auto value = cast(Result.RowImpl!int)v;
							} else static if(is(R == Long) || is(R == long) || is(R == ulong)) {
								auto value = cast(Result.RowImpl!long)v;
							} else static if(is(R == Float) || is(R == float)) {
								auto value = cast(Result.RowImpl!float)v;
							} else static if(is(R == Double) || is(R == double)) {
								auto value = cast(Result.RowImpl!double)v;
							} else static if(is(R == Char) || is(R == char)) {
								auto value = cast(Result.RowImpl!char)v;
							} else static if(is(R == String) || is(R == Clob) || is(R == string)) {
								auto value = cast(Result.RowImpl!string)v;
							} else static if(is(R == Binary) || is(R == Blob) || is(R == ubyte[])) {
								auto value = cast(Result.RowImpl!(ubyte[]))v;
							} else static if(is(R == Date) || is(R == std.datetime.Date)) {
								auto value = cast(Result.RowImpl!(std.datetime.Date))v;
							} else static if(is(R == DateTime) || is(R == std.datetime.DateTime)) {
								auto value = cast(Result.RowImpl!(std.datetime.DateTime))v;
							} else static if(is(R == Time) || is(R == std.datetime.TimeOfDay)) {
								auto value = cast(Result.RowImpl!(std.datetime.TimeOfDay))v;
							}
							// a failed class downcast yields null: the stored value has another type
							if(value is null) throw new DatabaseException("Could not cast " ~ row[*ptr].toString() ~ " to " ~ R.stringof);
							mixin("entity." ~ member) = value.value;
						}
					}
				}
			}
		}

		/**
		 * Type-erased cell of a result row; the concrete value is
		 * stored in a `RowImpl`.
		 */
		static class Row {

			static Row from(T)(T value) {
				RowImpl!T ret = new RowImpl!T();
				ret.value = value;
				return ret;
			}

		}

		static class RowImpl(T) : Row {

			T value;

			override string toString() {
				import std.conv;
				return value.to!string;
			}

		}

	}

	protected enum Type : uint {

		BOOL = 1 << 0,
		BYTE = 1 << 1,
		SHORT = 1 << 2,
		INT = 1 << 3,
		LONG = 1 << 4,
		FLOAT = 1 << 5,
		DOUBLE = 1 << 6,
		CHAR = 1 << 7,
		STRING = 1 << 8,
		BINARY = 1 << 9,
		CLOB = 1 << 10,
		BLOB = 1 << 11,
		DATE = 1 << 12,
		DATETIME = 1 << 13,
		TIME = 1 << 14,

	}

	/*
	 * Converts a value to its textual representation for a query,
	 * delegating strings, binaries and dates to the driver-specific
	 * `escape*` methods. Null nullable values become `null`.
	 */
	protected string escape(T)(T value) {
		static if(is(T == Char) || is(T == char)) {
			return escapeString([value]);
		} else static if(is(T == String) || is(T == Clob) || is(T : string)) {
			return escapeString(value);
		} else static if(is(T == Binary) || is(T == Blob) || is(T : ubyte[])) {
			return escapeBinary(value);
		} else static if(is(T == Date) || is(T == std.datetime.Date)) {
			return escapeDate(value);
		} else static if(is(T == DateTime) || is(T == std.datetime.DateTime)) {
			return escapeDateTime(value);
		} else static if(is(T == Time) || is(T == std.datetime.TimeOfDay)) {
			return escapeTime(value);
		} else static if(is(T : Nullable!R, R)) {
			if(value.isNull) return "null";
			else return value.value.to!string;
		} else {
			return value.to!string;
		}
	}

	protected abstract string escapeString(string);

	protected abstract string escapeBinary(ubyte[]);

	protected abstract string escapeDate(std.datetime.Date);

	protected abstract string escapeDateTime(std.datetime.DateTime);

	protected abstract string escapeTime(std.datetime.TimeOfDay);

}

/**
 * Indicates whether an entity is valid.
 * A valid entity extends `Entity`, is not abstract, can be constructed
 * with `new T()` (no-argument constructor) and has no duplicated
 * database member names.
 */
public bool isValidEntity(T)() {
	static if(!is(T : Entity) || !__traits(compiles, new T())) {
		return false;
	} else {
		import std.algorithm : sort, uniq;
		import std.array : array;
		string[] members;
		static foreach(immutable member ; getEntityMembers!T) members ~= memberName!(T, member);
		sort(members);
		return uniq(members).array.length == members.length;
	}
}

///
unittest {

	static struct Invalid0 {}

	static class Invalid1 {}

	static abstract class Invalid2 {}

	static class Invalid3 : Entity {

		override string tableName() { return "test"; }

		this(int i) {}

	}

	static class Invalid4 : Entity {

		override string tableName() { return "test"; }

		String test;

		@Name("test")
		String test0;

	}

	static assert(!isValidEntity!Invalid0);
	static assert(!isValidEntity!Invalid1);
	static assert(!isValidEntity!Invalid2);
	static assert(!isValidEntity!Invalid3);
	static assert(!isValidEntity!Invalid4);

}

/*
 * Database name of a member: the value of its `@Name` attribute or,
 * when absent, the member name converted to snake case.
 * Not restricted to `Entity` so that `Result.bind`/`apply` work on
 * plain classes and structs as documented.
 */
private string memberName(T, string member)() {
	static if(hasUDA!(__traits(getMember, T, member), Name)) {
		return getUDAs!(__traits(getMember, T, member), Name)[0].name;
	} else {
		return member.toSnakeCase();
	}
}

/*
 * Maps a member type to its `Database.Type`; fails to compile for
 * unsupported types.
 */
private Database.Type memberType(T)() {
	with(Database.Type) {
		static if(is(T == Bool) || is(T == bool)) {
			return BOOL;
		} else static if(is(T == Byte) || is(T == byte) || is(T == ubyte)) {
			return BYTE;
		} else static if(is(T == Short) || is(T == short) || is(T == ushort)) {
			return SHORT;
		} else static if(is(T == Integer) || is(T == int) || is(T == uint)) {
			return INT;
		} else static if(is(T == Long) || is(T == long) || is(T == ulong)) {
			return LONG;
		} else static if(is(T == Float) || is(T == float)) {
			return FLOAT;
		} else static if(is(T == Double) || is(T == double)) {
			return DOUBLE;
		} else static if(is(T == Char) || is(T == char)) {
			return CHAR;
		} else static if(is(T == Clob)) {
			return CLOB;
		} else static if(is(T == Blob)) {
			return BLOB;
		} else static if(is(T == String) || is(T : string)) {
			return STRING;
		} else static if(is(T == Binary) || is(T : ubyte[])) {
			return BINARY;
		} else static if(is(T == Date) || is(T == std.datetime.Date)) {
			return DATE;
		} else static if(is(T == DateTime) || is(T == std.datetime.DateTime)) {
			return DATETIME;
		} else static if(is(T == Time) || is(T == std.datetime.TimeOfDay)) {
			return TIME;
		} else {
			static assert(0, "Member of type " ~ T.stringof ~ " is not valid");
		}
	}
}

/*
 * Whether a member type is a `Nullable` instantiation.
 */
private bool memberNullable(T)() {
	static if(is(T : Nullable!R, R)) return true;
	else return false;
}

/*
 * Names of the assignable, non-function members of `T`.
 * Not restricted to `Entity` so that `Result.bind`/`apply` work on
 * plain classes and structs as documented.
 */
private string[] getEntityMembers(T)() {
	string[] ret;
	foreach(immutable member ; __traits(allMembers, T)) {
		static if(!is(typeof(__traits(getMember, T, member)) == function) && __traits(compiles, mixin("new T()." ~ member ~ "=T." ~ member ~ ".init"))) {
			ret ~= member;
		}
	}
	return ret;
}

/*
 * D names of the members of `T` annotated with `@PrimaryKey`.
 */
private string[] getEntityPrimaryKeys(T:Entity)() {
	string[] ret;
	static foreach(immutable member ; getEntityMembers!T) {
		static if(hasUDA!(__traits(getMember, T, member), PrimaryKey)) ret ~= member;
	}
	return ret;
}

/**
 * Generic database exception.
 */
class DatabaseException : Exception {

	public this(string msg, string file=__FILE__, size_t line=__LINE__) {
		super(msg, file, line);
	}

}

/**
 * Exception thrown when an error occurs during the connection,
 * like an unexpected or malformed packet.
 */
class DatabaseConnectionException : DatabaseException {

	public this(string msg, string file=__FILE__, size_t line=__LINE__) {
		super(msg, file, line);
	}

}

/**
 * Database exception carrying a driver-specific error code.
 * The message is prefixed with `(dbname-errorCode)`.
 */
class ErrorCodeDatabaseException(string dbname, T) : DatabaseException {

	private T _errorCode;

	public this(T errorCode, string msg, string file=__FILE__, size_t line=__LINE__) {
		super("(" ~ dbname ~ "-" ~ errorCode.to!string ~ ") " ~ msg, file, line);
		_errorCode = errorCode;
	}

	public @property T errorCode() {
		return _errorCode;
	}

}

/**
 * Database exception that groups several errors; its message is the
 * comma-separated list of the messages of the given errors.
 */
class ErrorCodesDatabaseException(T:Exception) : DatabaseException {

	private T[] _errors;

	public this(T[] errors, string file=__FILE__, size_t line=__LINE__) {
		string[] messages;
		foreach(error ; errors) {
			messages ~= error.msg;
		}
		super(messages.join(", "), file, line);
		// previously never stored, so `errors` always returned null
		_errors = errors;
	}

	/// The errors given to the constructor.
	public @property T[] errors() {
		return _errors;
	}

}