module shark.impl.mysql;

import std.algorithm : max, canFind;
import std.array : join;
import std.conv : to;
import std.digest : toHexString;
import std.digest.sha : sha1Of, sha256Of;
import std.exception : enforce;
import std.experimental.logger : trace;
import std.socket;
import std.system : Endian;

import shark.database : DatabaseConnectionException, ErrorCodeDatabaseException;
import shark.sql : SqlDatabase;
import shark.util : Stream, read0String, write0String;

import xbuffer : Buffer;

// debug
import std.stdio;

/**
 * Character set ids accepted by MysqlDatabase's constructor. The chosen
 * value is sent as the one-byte character set field of the handshake
 * response.
 */
enum CharacterSet : ubyte {

	latin1 = 8,
	latin2 = 9,
	ascii = 11,
	utf8 = 33,
	utf16 = 54,
	utf16le = 56,
	utf32 = 60,
	binary = 63

}

/**
 * Capability bits exchanged during the handshake: read from the server's
 * initial handshake packet and sent back in the client's handshake response.
 */
private enum CapabilityFlags : uint {

	longPassword = 0x00000001,
	foundRows = 0x00000002,
	longFlag = 0x00000004,
	connectWithDb = 0x00000008,
	noSchema = 0x00000010,
	compress = 0x00000020,
	odbc = 0x00000040,
	localFiles = 0x00000080,
	ignoreSpace = 0x00000100,
	protocol41 = 0x00000200,
	interactive = 0x00000400,
	ssl = 0x00000800,
	ignoreSigpipe = 0x00001000,
	transactions = 0x00002000,
	reserved = 0x00004000,
	secureConnection = 0x00008000,
	multiStatements = 0x00010000,
	multiResults = 0x00020000,
	psMultiResults = 0x00040000,
	pluginAuth = 0x00080000,

}

// NOTE(review): presumably frames packets with a 3-byte little-endian length
// followed by a one-byte sequence id — confirm against shark.util.Stream.
private alias MysqlStream = Stream!(0, Endian.littleEndian, 3, false, Endian.littleEndian, ubyte);

/**
 * SqlDatabase implementation that talks to a MySQL server over a TCP socket.
 */
class MysqlDatabase : SqlDatabase {

	/// Maximum packet size advertised to the server in the handshake response (16 MiB).
	private enum uint maxPacketSize = 16_777_216;

	private immutable ubyte characterSet;

	private MysqlStream _stream;

	private string _serverVersion;

	/**
	 * Opens a blocking TCP connection to the server. The MySQL handshake
	 * and authentication are performed separately by connectImpl.
	 * Params:
	 *   host = host name or address of the server
	 *   port = TCP port of the server
	 *   characterSet = character set id sent in the handshake response
	 */
	public this(string host, ushort port=3306, ubyte characterSet=CharacterSet.utf8) {
		this.characterSet = characterSet;
		Socket socket = new TcpSocket();
		// don't leak the socket if address resolution or the connection fails
		scope(failure) socket.close();
		socket.blocking = true;
		socket.connect(getAddress(host, port)[0]);
		_stream = new MysqlStream(socket, 1024);
	}

	/**
	 * Gets MySQL server's version as indicated in the handshake
	 * process by the server.
	 */
	public @property string serverVersion() {
		return _serverVersion;
	}

	/**
	 * Performs the connection phase: parses the server's initial handshake
	 * packet (protocol version 10), answers with a protocol 4.1 handshake
	 * response carrying the credentials, then reads the server's verdict.
	 * Params:
	 *   db = database name sent in the handshake response
	 *   user = user name
	 *   password = plain text password; scrambled according to the
	 *              authentication plugin announced by the server
	 * Throws:
	 *   DatabaseConnectionException when the server uses an unsupported
	 *   protocol or authentication exchange; MysqlDatabaseException when
	 *   the server answers with an error packet (e.g. wrong credentials).
	 */
	protected override void connectImpl(string db, string user, string password) {
		Buffer buffer = _stream.receive();
		enforce!DatabaseConnectionException(buffer.read!ubyte() == 0x0a, "Incompatible protocols");
		_serverVersion = buffer.read0String().idup;
		buffer.readData(4); // connection id
		ubyte[] authPluginData = buffer.read!(ubyte[])(8).dup;
		buffer.readData(1); // filler
		uint capabilities = buffer.read!(Endian.littleEndian, ushort)();
		buffer.read!ubyte(); // character set
		buffer.readData(2); // status flags
		capabilities |= (buffer.read!(Endian.littleEndian, ushort)() << 16);
		// unsigned: a signed byte would turn lengths above 127 negative
		immutable authPluginDataLength = buffer.read!ubyte();
		buffer.readData(10); // reserved
		if(capabilities & CapabilityFlags.secureConnection) {
			authPluginData ~= buffer.read!(ubyte[])(max(13, authPluginDataLength - 8)).dup;
			authPluginData = authPluginData[0..$-1]; // remove final 0
		}
		// servers that don't announce plugin authentication use the native scheme
		string method = "mysql_native_password";
		if(capabilities & CapabilityFlags.pluginAuth) {
			method = buffer.read0String().idup;
			enforce!DatabaseConnectionException(["mysql_native_password", "caching_sha2_password"].canFind(method), "Unknown hashing method '" ~ method ~ "'");
		}
		enforce!DatabaseConnectionException(capabilities & CapabilityFlags.protocol41, "Server does not support protocol v4.1");
		buffer.reset();
		buffer.write!(Endian.littleEndian, uint)(CapabilityFlags.protocol41 | CapabilityFlags.connectWithDb | CapabilityFlags.secureConnection | CapabilityFlags.pluginAuth);
		buffer.write!(Endian.littleEndian, uint)(maxPacketSize);
		buffer.write(characterSet);
		buffer.writeData(new void[23]); // reserved
		buffer.write0String(user);
		if(password.length) {
			// native: SHA1 with the nonce hashed first; caching_sha2: SHA256 with the nonce hashed last
			ubyte[] hash = method == "mysql_native_password" ? hashPassword!(sha1Of, true)(password, authPluginData) : hashPassword!(sha256Of, false)(password, authPluginData);
			buffer.write(hash.length.to!ubyte);
			buffer.writeData(hash);
		} else {
			buffer.write(ubyte(0));
		}
		buffer.write0String(db);
		buffer.write0String(method);
		_stream.send(buffer);
		readAuthenticationResult(method);
	}

	/**
	 * Consumes the server's answer to the handshake response so that later
	 * commands don't read it as their own reply.
	 * Throws:
	 *   MysqlDatabaseException on an error packet; DatabaseConnectionException
	 *   when the server requests an exchange this client doesn't implement
	 *   (authentication method switch, full caching_sha2_password auth).
	 */
	private void readAuthenticationResult(string method) {
		Buffer response = receive(); // throws on an error packet
		immutable header = response.peek!ubyte();
		enforce!DatabaseConnectionException(header != 0xfe, "Authentication method switch is not supported");
		if(header == 0x01 && method == "caching_sha2_password") {
			// "more data" packet: 0x03 means fast authentication succeeded and an
			// OK packet follows; 0x04 requests full authentication (TLS or RSA)
			response.readData(1);
			enforce!DatabaseConnectionException(response.read!ubyte() == 0x03, "Full caching_sha2_password authentication is not supported");
			receive();
		}
	}

	/**
	 * Computes the password scramble `H(password) XOR mask` where
	 * mask = H(nonce ~ H(H(password))) when nonceFirst is true
	 * (mysql_native_password, H = SHA1) and
	 * mask = H(H(H(password)) ~ nonce) otherwise
	 * (caching_sha2_password, H = SHA256).
	 */
	private static ubyte[] hashPassword(alias method, bool nonceFirst)(string password, const(ubyte)[] nonce) {
		auto password1 = method(password);
		auto password2 = method(password1);
		static if(nonceFirst) {
			auto mask = method(nonce, password2);
		} else {
			auto mask = method(password2, nonce);
		}
		ubyte[] res = password1.dup;
		foreach(i, ref r; res) {
			r ^= mask[i];
		}
		return res;
	}

	protected override void closeImpl() {
		_stream.socket.close();
	}

	/**
	 * Receives one packet from the server.
	 * Throws: MysqlDatabaseException built from the error code and message
	 *         when the packet is an error packet (first byte 0xff).
	 */
	private Buffer receive() {
		Buffer buffer = _stream.receive();
		if(buffer.peek!ubyte() == 0xff) {
			buffer.readData(1);
			immutable errorCode = buffer.read!(Endian.littleEndian, ushort)();
			buffer.readData(6); // sql state marker and sql state
			throw new MysqlDatabaseException(errorCode, cast(string)buffer.data);
		}
		return buffer;
	}

	/**
	 * Sends the SQL text as a COM_QUERY command (starting a new packet
	 * sequence) and reads the first response packet, throwing
	 * MysqlDatabaseException if it is an error packet.
	 */
	public override void query(string query) {
		trace("Running query `" ~ query ~ "`");
		Buffer buffer = new Buffer(query.length + 5);
		buffer.write(ubyte(3));
		buffer.write(query);
		_stream.resetSequence();
		_stream.send(buffer);
		buffer = receive();
		buffer.data.writeln;
		//return receive();
	}

	/// Not implemented yet: always throws.
	public override Result querySelect(string query) {
		throw new Exception("Not implemented");
	}

	/// Not implemented yet: always returns null.
	protected override TableInfo[string] getTableInfo(string table) {
		//query("describe " ~ table ~ ";");
		return null;
	}

	/**
	 * Builds a column definition: name, SQL type with optional length,
	 * then the auto_increment / not null / unique modifiers.
	 */
	protected override string generateField(InitInfo.Field field) {
		string[] ret = [field.name];
		ret ~= convertType(cast(Type)field.type) ~ (field.length ? "(" ~ field.length.to!string ~ ")" : "");
		if(field.autoIncrement) ret ~= "auto_increment";
		if(!field.nullable) ret ~= "not null";
		if(field.unique) ret ~= "unique";
		return ret.join(" ");
	}

	/// Maps an abstract column type to its MySQL type name.
	private string convertType(Type type) {
		final switch(type) with(Type) {
			case BOOL: return "boolean";
			case BYTE: return "tinyint";
			case SHORT: return "smallint";
			case INT: return "int";
			case LONG: return "bigint";
			case FLOAT: return "float";
			case DOUBLE: return "double";
			case CHAR: return "char";
			case STRING: return "varchar";
			case BINARY: return "binary";
			case CLOB: return "text"; // MySQL has no CLOB type; TEXT[(M)] is the character counterpart of BLOB[(M)]
			case BLOB: return "blob";
			case DATE: return "date";
			case DATETIME: return "datetime";
			case TIME: return "time";
		}
	}

	protected override void alterTableColumn(string table, InitInfo.Field field, bool typeChanged, bool nullableChanged) {
		query("alter table " ~ table ~ " modify column " ~ generateField(field) ~ ";");
	}

	/// Not implemented yet: always throws.
	protected override Result insertInto(string table, string[] names, string[] fields, string[] primaryKeys) {
		throw new Exception("Not implemented");
	}

	// UTILS

	protected override string randomFunction() {
		return "rand()";
	}

	/**
	 * Renders binary data as a MySQL hexadecimal literal. An empty array
	 * becomes x'' because a bare 0x is not a valid literal.
	 */
	protected override string escapeBinary(ubyte[] value) {
		if(value.length == 0) return "x''";
		return "0x" ~ toHexString(value);
	}

}

alias MysqlDatabaseException = ErrorCodeDatabaseException!("MySQL", ushort);