1 module shark.impl.mysql;
2 
import std.algorithm : max, canFind;
import std.conv : to;
import std.exception : enforce;
import std.experimental.logger : trace;
import std.digest : toHexString;
import std.digest.sha : sha1Of, sha256Of;
import std.socket;
import std.string : join;
import std.system : Endian;
12 
13 import shark.database : DatabaseConnectionException, ErrorCodeDatabaseException;
14 import shark.sql : SqlDatabase;
15 import shark.util : Stream, read0String, write0String;
16 
17 import xbuffer : Buffer;
18 
19 // debug
20 import std.stdio;
21 
/**
 * Character sets that can be requested when constructing a `MysqlDatabase`.
 *
 * The chosen value is written verbatim as the one-byte character-set field
 * of the handshake response.
 * NOTE(review): the numbers look like MySQL collation ids (each charset's
 * default collation) — confirm against `SHOW COLLATION` before adding more.
 */
enum CharacterSet : ubyte {
	latin1 = 8,
	latin2 = 9,
	ascii = 11,
	utf8 = 33,
	utf16 = 54,
	utf16le = 56,
	utf32 = 60,
	binary = 63,
}
34 
/**
 * Capability bits of the MySQL handshake.
 *
 * The server's set is assembled from two 16-bit halves of its initial
 * handshake packet; the client's set is written back as a 32-bit
 * little-endian word in the handshake response. Each member is a single
 * bit, written here as its bit position.
 */
private enum CapabilityFlags : uint {
	longPassword     = 1 << 0,
	foundRows        = 1 << 1,
	longFlag         = 1 << 2,
	connectWithDb    = 1 << 3,
	noSchema         = 1 << 4,
	compress         = 1 << 5,
	odbc             = 1 << 6,
	localFiles       = 1 << 7,
	ignoreSpace      = 1 << 8,
	protocol41       = 1 << 9,
	interactive      = 1 << 10,
	ssl              = 1 << 11,
	ignoreSigpipe    = 1 << 12,
	transactions     = 1 << 13,
	reserved         = 1 << 14,
	secureConnection = 1 << 15,
	multiStatements  = 1 << 16,
	multiResults     = 1 << 17,
	psMultiResults   = 1 << 18,
	pluginAuth       = 1 << 19,
}
59 
/// Packet stream that `MysqlDatabase` wraps around its TCP socket to send and
/// receive protocol packets.
/// NOTE(review): the template arguments are interpreted by `shark.util.Stream`;
/// they presumably describe a 3-byte little-endian packet-length header plus a
/// one-byte sequence id (see `resetSequence` in `query`) — confirm in shark.util.
private alias MysqlStream = Stream!(0, Endian.littleEndian, 3, false, Endian.littleEndian, ubyte);
61 
/**
 * `SqlDatabase` implementation speaking the MySQL client/server protocol
 * (protocol v4.1) directly over a TCP socket.
 */
class MysqlDatabase : SqlDatabase {

	/// Character-set byte written into the handshake response.
	private immutable ubyte characterSet;

	private MysqlStream _stream;
	private void[] _buffer;

	private string _serverVersion;

	/**
	 * Opens a TCP connection to a MySQL server. Authentication is done
	 * later, by `connectImpl`.
	 * Params:
	 *     host = host name or address of the server
	 *     port = TCP port of the server
	 *     characterSet = character-set byte sent in the handshake response
	 *                    (see `CharacterSet`)
	 * Throws:
	 *     DatabaseConnectionException if `host` resolves to no address;
	 *     std.socket exceptions if resolution or connection fails.
	 */
	public this(string host, ushort port=3306, ubyte characterSet=CharacterSet.utf8) {
		this.characterSet = characterSet;
		Address[] addresses = getAddress(host, port);
		enforce!DatabaseConnectionException(addresses.length, "Could not resolve host '" ~ host ~ "'");
		// Create the socket for the resolved address family: a default
		// `new TcpSocket()` is IPv4-only and cannot connect to an IPv6 address.
		Socket socket = new TcpSocket(addresses[0].addressFamily);
		scope(failure) socket.close(); // don't leak the descriptor if connecting fails
		socket.blocking = true;
		socket.connect(addresses[0]);
		_stream = new MysqlStream(socket, 1024);
	}

	/**
	 * Gets MySQL server's version as indicated in the handshake
	 * process by the server.
	 */
	public @property string serverVersion() {
		return _serverVersion;
	}

	/**
	 * Performs the protocol v4.1 handshake: parses the server's initial
	 * handshake packet, answers with a handshake response that logs in as
	 * `user` with `password` and selects database `db`, then reads the
	 * server's verdict.
	 * Throws:
	 *     DatabaseConnectionException if the server's protocol, authentication
	 *     method or authentication flow is not supported;
	 *     MysqlDatabaseException if the server answers with an error packet.
	 */
	protected override void connectImpl(string db, string user, string password) {
		Buffer buffer = _stream.receive();
		if(buffer.peek!ubyte() == 0xff) {
			// The server can refuse the connection instead of sending a handshake.
			// No capabilities are negotiated yet, so this error packet has no
			// SQL-state section: the message follows the error code directly.
			buffer.readData(1);
			immutable errorCode = buffer.read!(Endian.littleEndian, ushort)();
			throw new MysqlDatabaseException(errorCode, (cast(const(char)[])buffer.data).idup);
		}
		enforce!DatabaseConnectionException(buffer.read!ubyte() == 0x0a, "Incompatible protocols");
		_serverVersion = buffer.read0String().idup;
		buffer.readData(4); // connection id
		ubyte[] authPluginData = buffer.read!(ubyte[])(8).dup;
		buffer.readData(1); // filler
		uint capabilities = buffer.read!(Endian.littleEndian, ushort)();
		buffer.read!ubyte(); // character set
		buffer.readData(2); // status flags
		capabilities |= (buffer.read!(Endian.littleEndian, ushort)() << 16);
		// unsigned length of the whole auth-plugin-data
		immutable authPluginDataLength = buffer.read!ubyte();
		buffer.readData(10); // reserved
		if(capabilities & CapabilityFlags.secureConnection) {
			authPluginData ~= buffer.read!(ubyte[])(max(13, authPluginDataLength - 8)).dup;
			authPluginData = authPluginData[0..$-1]; // remove final 0
		}
		// A server that does not announce its plugin only supports the native method.
		string method = "mysql_native_password";
		if(capabilities & CapabilityFlags.pluginAuth) {
			method = buffer.read0String().idup;
			enforce!DatabaseConnectionException(["mysql_native_password", "caching_sha2_password"].canFind(method), "Unknown hashing method '" ~ method ~ "'");
		}
		enforce!DatabaseConnectionException(capabilities & CapabilityFlags.protocol41, "Server does not support protocol v4.1");
		buffer.reset();
		buffer.write!(Endian.littleEndian, uint)(CapabilityFlags.protocol41 | CapabilityFlags.connectWithDb | CapabilityFlags.secureConnection | CapabilityFlags.pluginAuth);
		// max packet size. NOTE(review): 1 is unusually small — confirm intended.
		buffer.write!(Endian.littleEndian, uint)(1);
		buffer.write(characterSet);
		buffer.writeData(new void[23]); // reserved
		buffer.write0String(user);
		if(password.length) {
			immutable hash = method == "mysql_native_password"
				? hashPassword!sha1Of(password, authPluginData, true)
				: hashPassword!sha256Of(password, authPluginData);
			buffer.write(hash.length.to!ubyte);
			buffer.write(hash);
		} else {
			buffer.write(ubyte(0));
		}
		buffer.write0String(db);
		buffer.write0String(method);
		_stream.send(buffer);
		readAuthResult();
	}

	/**
	 * Reads the server's answer to the handshake response. Succeeds on an
	 * OK packet, optionally preceded by caching_sha2_password's
	 * "fast authentication succeeded" notice (0x01 0x03).
	 * Throws:
	 *     MysqlDatabaseException if the server rejects the login;
	 *     DatabaseConnectionException if the server requests full
	 *     caching_sha2_password authentication or an authentication method
	 *     switch (neither is implemented), or replies with anything else.
	 */
	private void readAuthResult() {
		Buffer result = receive();
		ubyte header = result.read!ubyte();
		if(header == 0x01) {
			// AuthMoreData: 0x03 means fast auth succeeded and an OK packet
			// follows; any other status asks for full authentication.
			enforce!DatabaseConnectionException(result.read!ubyte() == 0x03, "Full caching_sha2_password authentication is not supported");
			result = receive();
			header = result.read!ubyte();
		}
		enforce!DatabaseConnectionException(header != 0xfe, "Authentication method switch is not supported");
		enforce!DatabaseConnectionException(header == 0x00, "Unexpected response to the handshake");
	}

	/**
	 * Scrambles `password` with the server's `nonce`:
	 *     mysql_native_password: SHA1(pw) XOR SHA1(nonce ~ SHA1(SHA1(pw)))
	 *     caching_sha2_password: SHA256(pw) XOR SHA256(SHA256(SHA256(pw)) ~ nonce)
	 * Params:
	 *     method = digest function (`sha1Of` or `sha256Of`)
	 *     password = clear-text password
	 *     nonce = auth-plugin-data received in the server handshake
	 *     nonceFirst = hash the nonce before the double hash (native password)
	 *                  instead of after it (caching_sha2_password)
	 * Returns: the scramble bytes, stored in a string
	 */
	private string hashPassword(alias method)(string password, const(ubyte)[] nonce, bool nonceFirst = false) {
		auto stage1 = method(password);
		auto stage2 = method(stage1);
		auto res = nonceFirst ? method(nonce, stage2) : method(stage2, nonce);
		foreach(i, ref r; res) {
			r ^= stage1[i];
		}
		return cast(string)res.idup;
	}

	/// Closes the underlying socket.
	protected override void closeImpl() {
		_stream.socket.close();
	}

	/**
	 * Receives the next packet, turning an error packet (first byte 0xff)
	 * into an exception.
	 * Throws: MysqlDatabaseException with the server's error code and message.
	 */
	private Buffer receive() {
		Buffer buffer = _stream.receive();
		if(buffer.peek!ubyte() == 0xff) {
			buffer.readData(1); // error header
			immutable errorCode = buffer.read!(Endian.littleEndian, ushort)();
			buffer.readData(6); // SQL-state marker '#' and 5-character SQL state
			// Copy the message so it does not alias packet memory the stream may reuse.
			throw new MysqlDatabaseException(errorCode, (cast(const(char)[])buffer.data).idup);
		}
		return buffer;
	}

	/**
	 * Sends `query` as a COM_QUERY command (0x03) on a fresh packet sequence
	 * and reads the first reply packet.
	 * Throws: MysqlDatabaseException if the server replies with an error.
	 * NOTE(review): only the first reply packet is consumed; statements that
	 * produce a multi-packet result set will leave packets unread.
	 */
	public override void query(string query) {
		trace("Running query `" ~ query ~ "`");
		Buffer buffer = new Buffer(query.length + 5);
		buffer.write(ubyte(3));
		buffer.write(query);
		_stream.resetSequence();
		_stream.send(buffer);
		receive();
	}

	/// Not implemented yet: always throws.
	public override Result querySelect(string query) {
		throw new Exception("Not implemented");
	}

	/// Not implemented yet: always returns null.
	protected override TableInfo[string] getTableInfo(string table) {
		//query("describe " ~ table ~ ";");
		return null;
	}

	/**
	 * Builds a column definition of the form
	 * `name type[(length)] [auto_increment] [not null] [unique]`.
	 */
	protected override string generateField(InitInfo.Field field) {
		string[] ret = [field.name];
		ret ~= convertType(cast(Type)field.type) ~ (field.length ? "(" ~ field.length.to!string ~ ")" : "");
		if(field.autoIncrement) ret ~= "auto_increment";
		if(!field.nullable) ret ~= "not null";
		if(field.unique) ret ~= "unique";
		return ret.join(" ");
	}

	/// Maps a generic column type to its MySQL type name.
	private string convertType(Type type) {
		final switch(type) with(Type) {
			case BOOL: return "boolean";
			case BYTE: return "tinyint";
			case SHORT: return "smallint";
			case INT: return "int";
			case LONG: return "bigint";
			case FLOAT: return "float";
			case DOUBLE: return "double";
			case CHAR: return "char";
			case STRING: return "varchar";
			case BINARY: return "binary";
			case CLOB: return "text"; // MySQL has no CLOB; TEXT is the character counterpart of BLOB
			case BLOB: return "blob";
			case DATE: return "date";
			case DATETIME: return "datetime";
			case TIME: return "time";
		}
	}

	/**
	 * Redefines a column with `alter table ... modify column`. The whole
	 * definition is restated, so `typeChanged` and `nullableChanged` are not
	 * needed to choose what to change.
	 */
	protected override void alterTableColumn(string table, InitInfo.Field field, bool typeChanged, bool nullableChanged) {
		query("alter table " ~ table ~ " modify column " ~ generateField(field) ~ ";");
	}

	/// Not implemented yet: always throws.
	protected override Result insertInto(string table, string[] names, string[] fields, string[] primaryKeys) {
		throw new Exception("Not implemented");
	}

	// UTILS

	/// SQL expression producing a random value.
	protected override string randomFunction() {
		return "rand()";
	}

	/**
	 * Formats `value` as a hexadecimal literal. An empty value becomes `X''`,
	 * because `0x` without digits is not a valid literal.
	 */
	protected override string escapeBinary(ubyte[] value) {
		if(value.length == 0) return "X''";
		return "0x" ~ toHexString(value);
	}

}
221 
/**
 * Exception thrown by `MysqlDatabase` when a received packet starts with
 * 0xff (an error packet). It carries the server's `ushort` error code and the
 * message that follows it (see `MysqlDatabase.receive`).
 */
alias MysqlDatabaseException = ErrorCodeDatabaseException!("MySQL", ushort);