503 lines
15 KiB
Go
503 lines
15 KiB
Go
package mssql
|
|
|
|
import (
|
|
"bytes"
|
|
"database/sql"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"net/url"
|
|
"os"
|
|
"testing"
|
|
"time"
|
|
|
|
"golang.org/x/net/context"
|
|
)
|
|
|
|
// MockTransport is an in-memory stand-in for a connection transport used by
// the tests in this file. It embeds bytes.Buffer, so anything written to it
// can be read back afterwards through the buffer's own methods.
type MockTransport struct {
	bytes.Buffer
}
|
|
|
|
func (t *MockTransport) Close() error {
|
|
return nil
|
|
}
|
|
|
|
func TestSendLogin(t *testing.T) {
|
|
memBuf := new(MockTransport)
|
|
buf := newTdsBuffer(1024, memBuf)
|
|
login := login{
|
|
TDSVersion: verTDS73,
|
|
PacketSize: 0x1000,
|
|
ClientProgVer: 0x01060100,
|
|
ClientPID: 100,
|
|
ClientTimeZone: -4 * 60,
|
|
ClientID: [6]byte{0x12, 0x34, 0x56, 0x78, 0x90, 0xab},
|
|
OptionFlags1: 0xe0,
|
|
OptionFlags3: 8,
|
|
HostName: "subdev1",
|
|
UserName: "test",
|
|
Password: "testpwd",
|
|
AppName: "appname",
|
|
ServerName: "servername",
|
|
CtlIntName: "library",
|
|
Language: "en",
|
|
Database: "database",
|
|
ClientLCID: 0x204,
|
|
AtchDBFile: "filepath",
|
|
}
|
|
err := sendLogin(buf, login)
|
|
if err != nil {
|
|
t.Error("sendLogin should succeed")
|
|
}
|
|
ref := []byte{
|
|
16, 1, 0, 222, 0, 0, 1, 0, 198 + 16, 0, 0, 0, 3, 0, 10, 115, 0, 16, 0, 0, 0, 1,
|
|
6, 1, 100, 0, 0, 0, 0, 0, 0, 0, 224, 0, 0, 8, 16, 255, 255, 255, 4, 2, 0,
|
|
0, 94, 0, 7, 0, 108, 0, 4, 0, 116, 0, 7, 0, 130, 0, 7, 0, 144, 0, 10, 0, 0,
|
|
0, 0, 0, 164, 0, 7, 0, 178, 0, 2, 0, 182, 0, 8, 0, 18, 52, 86, 120, 144, 171,
|
|
198, 0, 0, 0, 198, 0, 8, 0, 214, 0, 0, 0, 0, 0, 0, 0, 115, 0, 117, 0, 98,
|
|
0, 100, 0, 101, 0, 118, 0, 49, 0, 116, 0, 101, 0, 115, 0, 116, 0, 226, 165,
|
|
243, 165, 146, 165, 226, 165, 162, 165, 210, 165, 227, 165, 97, 0, 112,
|
|
0, 112, 0, 110, 0, 97, 0, 109, 0, 101, 0, 115, 0, 101, 0, 114, 0, 118, 0,
|
|
101, 0, 114, 0, 110, 0, 97, 0, 109, 0, 101, 0, 108, 0, 105, 0, 98, 0, 114,
|
|
0, 97, 0, 114, 0, 121, 0, 101, 0, 110, 0, 100, 0, 97, 0, 116, 0, 97, 0, 98,
|
|
0, 97, 0, 115, 0, 101, 0, 102, 0, 105, 0, 108, 0, 101, 0, 112, 0, 97, 0,
|
|
116, 0, 104, 0}
|
|
out := memBuf.Bytes()
|
|
if !bytes.Equal(ref, out) {
|
|
fmt.Println("Expected:")
|
|
fmt.Print(hex.Dump(ref))
|
|
fmt.Println("Returned:")
|
|
fmt.Print(hex.Dump(out))
|
|
t.Error("input output don't match")
|
|
}
|
|
}
|
|
|
|
func TestSendSqlBatch(t *testing.T) {
|
|
checkConnStr(t)
|
|
p, err := parseConnectParams(makeConnStr(t).String())
|
|
if err != nil {
|
|
t.Error("parseConnectParams failed:", err.Error())
|
|
return
|
|
}
|
|
|
|
conn, err := connect(optionalLogger{testLogger{t}}, p)
|
|
if err != nil {
|
|
t.Error("Open connection failed:", err.Error())
|
|
return
|
|
}
|
|
defer conn.buf.transport.Close()
|
|
|
|
headers := []headerStruct{
|
|
{hdrtype: dataStmHdrTransDescr,
|
|
data: transDescrHdr{0, 1}.pack()},
|
|
}
|
|
err = sendSqlBatch72(conn.buf, "select 1", headers)
|
|
if err != nil {
|
|
t.Error("Sending sql batch failed", err.Error())
|
|
return
|
|
}
|
|
|
|
ch := make(chan tokenStruct, 5)
|
|
go processResponse(context.Background(), conn, ch)
|
|
|
|
var lastRow []interface{}
|
|
loop:
|
|
for tok := range ch {
|
|
switch token := tok.(type) {
|
|
case doneStruct:
|
|
break loop
|
|
case []columnStruct:
|
|
conn.columns = token
|
|
case []interface{}:
|
|
lastRow = token
|
|
default:
|
|
fmt.Println("unknown token", tok)
|
|
}
|
|
}
|
|
|
|
if len(lastRow) == 0 {
|
|
t.Fatal("expected row but no row set")
|
|
}
|
|
|
|
switch value := lastRow[0].(type) {
|
|
case int32:
|
|
if value != 1 {
|
|
t.Error("Invalid value returned, should be 1", value)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// checkConnStr skips the calling test unless the environment describes a
// database to talk to: either SQLSERVER_DSN is non-empty, or both HOST and
// DATABASE are non-empty.
func checkConnStr(t *testing.T) {
	hasDSN := os.Getenv("SQLSERVER_DSN") != ""
	hasHostAndDB := os.Getenv("HOST") != "" && os.Getenv("DATABASE") != ""
	if !hasDSN && !hasHostAndDB {
		t.Skip("no database connection string")
	}
}
|
|
|
|
// makeConnStr builds the connection URL used by the tests and returns it as
// a *url.URL so that individual tests can adjust it before turning it into a
// DSN string.
//
// When SQLSERVER_DSN is set it is parsed as a URL (the test is aborted if
// parsing fails) and its "log" query parameter is forced to "127". Otherwise
// a "sqlserver" URL is assembled from the HOST, INSTANCE, SQLUSER,
// SQLPASSWORD and DATABASE environment variables, again with log=127.
func makeConnStr(t *testing.T) *url.URL {
	if dsn := os.Getenv("SQLSERVER_DSN"); dsn != "" {
		u, err := url.Parse(dsn)
		if err != nil {
			t.Fatal("unable to parse SQLSERVER_DSN as URL", err)
		}
		query := u.Query()
		query.Set("log", "127")
		u.RawQuery = query.Encode()
		return u
	}

	query := url.Values{}
	query.Set("log", "127")
	query.Set("database", os.Getenv("DATABASE"))

	u := new(url.URL)
	u.Scheme = "sqlserver"
	u.Host = os.Getenv("HOST")
	u.Path = os.Getenv("INSTANCE")
	u.User = url.UserPassword(os.Getenv("SQLUSER"), os.Getenv("SQLPASSWORD"))
	u.RawQuery = query.Encode()
	return u
}
|
|
|
|
// testLogger forwards log output to the log of the *testing.T it wraps.
type testLogger struct {
	t *testing.T
}
|
|
|
|
func (l testLogger) Printf(format string, v ...interface{}) {
|
|
l.t.Logf(format, v...)
|
|
}
|
|
|
|
func (l testLogger) Println(v ...interface{}) {
|
|
l.t.Log(v...)
|
|
}
|
|
|
|
func open(t *testing.T) *sql.DB {
|
|
checkConnStr(t)
|
|
SetLogger(testLogger{t})
|
|
conn, err := sql.Open("mssql", makeConnStr(t).String())
|
|
if err != nil {
|
|
t.Error("Open connection failed:", err.Error())
|
|
return nil
|
|
}
|
|
|
|
return conn
|
|
}
|
|
|
|
func TestConnect(t *testing.T) {
|
|
checkConnStr(t)
|
|
SetLogger(testLogger{t})
|
|
conn, err := sql.Open("mssql", makeConnStr(t).String())
|
|
if err != nil {
|
|
t.Error("Open connection failed:", err.Error())
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
}
|
|
|
|
// simpleQuery prepares the statement "select 1 as a" on conn. On success it
// returns the prepared statement; on failure it records a test error on t
// and returns nil.
func simpleQuery(conn *sql.DB, t *testing.T) (stmt *sql.Stmt) {
	prepared, err := conn.Prepare("select 1 as a")
	if err == nil {
		return prepared
	}
	t.Error("Prepare failed:", err.Error())
	return nil
}
|
|
|
|
// checkSimpleQuery consumes rows and reports a test error on t unless the
// result set consists of exactly one row whose single column scans to the
// integer 1. Scan failures and errors that ended the iteration are reported
// as well.
func checkSimpleQuery(rows *sql.Rows, t *testing.T) {
	numrows := 0
	for rows.Next() {
		numrows++
		var val int
		if err := rows.Scan(&val); err != nil {
			t.Error("Scan failed:", err.Error())
			// val was not populated; a value mismatch report would be noise.
			continue
		}
		if val != 1 {
			t.Error("query should return 1")
		}
	}
	// Next returns false both at the end of the result set and when
	// iteration fails; Err tells the two apart.
	if err := rows.Err(); err != nil {
		t.Error("Iterating rows failed:", err.Error())
	}
	if numrows != 1 {
		t.Error("query should return 1 row, returned", numrows)
	}
}
|
|
|
|
func TestQuery(t *testing.T) {
|
|
conn := open(t)
|
|
if conn == nil {
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
stmt := simpleQuery(conn, t)
|
|
if stmt == nil {
|
|
return
|
|
}
|
|
defer stmt.Close()
|
|
|
|
rows, err := stmt.Query()
|
|
if err != nil {
|
|
t.Error("Query failed:", err.Error())
|
|
}
|
|
defer rows.Close()
|
|
|
|
columns, err := rows.Columns()
|
|
if err != nil {
|
|
t.Error("getting columns failed", err.Error())
|
|
}
|
|
if len(columns) != 1 && columns[0] != "a" {
|
|
t.Error("returned incorrect columns (expected ['a']):", columns)
|
|
}
|
|
|
|
checkSimpleQuery(rows, t)
|
|
}
|
|
|
|
func TestMultipleQueriesSequentialy(t *testing.T) {
|
|
|
|
conn := open(t)
|
|
defer conn.Close()
|
|
|
|
stmt, err := conn.Prepare("select 1 as a")
|
|
if err != nil {
|
|
t.Error("Prepare failed:", err.Error())
|
|
return
|
|
}
|
|
defer stmt.Close()
|
|
|
|
rows, err := stmt.Query()
|
|
if err != nil {
|
|
t.Error("Query failed:", err.Error())
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
checkSimpleQuery(rows, t)
|
|
|
|
rows, err = stmt.Query()
|
|
if err != nil {
|
|
t.Error("Query failed:", err.Error())
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
checkSimpleQuery(rows, t)
|
|
}
|
|
|
|
func TestMultipleQueryClose(t *testing.T) {
|
|
conn := open(t)
|
|
defer conn.Close()
|
|
|
|
stmt, err := conn.Prepare("select 1 as a")
|
|
if err != nil {
|
|
t.Error("Prepare failed:", err.Error())
|
|
return
|
|
}
|
|
defer stmt.Close()
|
|
|
|
rows, err := stmt.Query()
|
|
if err != nil {
|
|
t.Error("Query failed:", err.Error())
|
|
return
|
|
}
|
|
rows.Close()
|
|
|
|
rows, err = stmt.Query()
|
|
if err != nil {
|
|
t.Error("Query failed:", err.Error())
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
checkSimpleQuery(rows, t)
|
|
}
|
|
|
|
func TestPing(t *testing.T) {
|
|
conn := open(t)
|
|
defer conn.Close()
|
|
conn.Ping()
|
|
}
|
|
|
|
func TestSecureWithInvalidHostName(t *testing.T) {
|
|
checkConnStr(t)
|
|
SetLogger(testLogger{t})
|
|
|
|
dsn := makeConnStr(t)
|
|
dsnParams := dsn.Query()
|
|
dsnParams.Set("encrypt", "true")
|
|
dsnParams.Set("TrustServerCertificate", "false")
|
|
dsnParams.Set("hostNameInCertificate", "foo.bar")
|
|
dsn.RawQuery = dsnParams.Encode()
|
|
|
|
conn, err := sql.Open("mssql", dsn.String())
|
|
if err != nil {
|
|
t.Fatal("Open connection failed:", err.Error())
|
|
}
|
|
defer conn.Close()
|
|
err = conn.Ping()
|
|
if err == nil {
|
|
t.Fatal("Connected to fake foo.bar server")
|
|
}
|
|
}
|
|
|
|
func TestSecureConnection(t *testing.T) {
|
|
checkConnStr(t)
|
|
SetLogger(testLogger{t})
|
|
|
|
dsn := makeConnStr(t)
|
|
dsnParams := dsn.Query()
|
|
dsnParams.Set("encrypt", "true")
|
|
dsnParams.Set("TrustServerCertificate", "true")
|
|
dsn.RawQuery = dsnParams.Encode()
|
|
|
|
conn, err := sql.Open("mssql", dsn.String())
|
|
if err != nil {
|
|
t.Fatal("Open connection failed:", err.Error())
|
|
}
|
|
defer conn.Close()
|
|
var msg string
|
|
err = conn.QueryRow("select 'secret'").Scan(&msg)
|
|
if err != nil {
|
|
t.Fatal("cannot scan value", err)
|
|
}
|
|
if msg != "secret" {
|
|
t.Fatal("expected secret, got: ", msg)
|
|
}
|
|
var secure bool
|
|
err = conn.QueryRow("select encrypt_option from sys.dm_exec_connections where session_id=@@SPID").Scan(&secure)
|
|
if err != nil {
|
|
t.Fatal("cannot scan value", err)
|
|
}
|
|
if !secure {
|
|
t.Fatal("connection is not encrypted")
|
|
}
|
|
}
|
|
|
|
func TestInvalidConnectionString(t *testing.T) {
|
|
connStrings := []string{
|
|
"log=invalid",
|
|
"port=invalid",
|
|
"packet size=invalid",
|
|
"connection timeout=invalid",
|
|
"dial timeout=invalid",
|
|
"keepalive=invalid",
|
|
"encrypt=invalid",
|
|
"trustservercertificate=invalid",
|
|
"failoverport=invalid",
|
|
|
|
// ODBC mode
|
|
"odbc:password={",
|
|
"odbc:password={somepass",
|
|
"odbc:password={somepass}}",
|
|
"odbc:password={some}pass",
|
|
}
|
|
for _, connStr := range connStrings {
|
|
_, err := parseConnectParams(connStr)
|
|
if err == nil {
|
|
t.Errorf("Connection expected to fail for connection string %s but it didn't", connStr)
|
|
continue
|
|
} else {
|
|
t.Logf("Connection failed for %s as expected with error %v", connStr, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestValidConnectionString(t *testing.T) {
|
|
type testStruct struct {
|
|
connStr string
|
|
check func(connectParams) bool
|
|
}
|
|
connStrings := []testStruct{
|
|
{"server=server\\instance;database=testdb;user id=tester;password=pwd", func(p connectParams) bool {
|
|
return p.host == "server" && p.instance == "instance" && p.user == "tester" && p.password == "pwd"
|
|
}},
|
|
{"server=.", func(p connectParams) bool { return p.host == "localhost" }},
|
|
{"server=(local)", func(p connectParams) bool { return p.host == "localhost" }},
|
|
{"ServerSPN=serverspn;Workstation ID=workstid", func(p connectParams) bool { return p.serverSPN == "serverspn" && p.workstation == "workstid" }},
|
|
{"failoverpartner=fopartner;failoverport=2000", func(p connectParams) bool { return p.failOverPartner == "fopartner" && p.failOverPort == 2000 }},
|
|
{"app name=appname;applicationintent=ReadOnly", func(p connectParams) bool { return p.appname == "appname" && (p.typeFlags&fReadOnlyIntent != 0) }},
|
|
{"encrypt=disable", func(p connectParams) bool { return p.disableEncryption }},
|
|
{"encrypt=true", func(p connectParams) bool { return p.encrypt && !p.disableEncryption }},
|
|
{"encrypt=false", func(p connectParams) bool { return !p.encrypt && !p.disableEncryption }},
|
|
{"trustservercertificate=true", func(p connectParams) bool { return p.trustServerCertificate }},
|
|
{"trustservercertificate=false", func(p connectParams) bool { return !p.trustServerCertificate }},
|
|
{"certificate=abc", func(p connectParams) bool { return p.certificate == "abc" }},
|
|
{"hostnameincertificate=abc", func(p connectParams) bool { return p.hostInCertificate == "abc" }},
|
|
{"connection timeout=3;dial timeout=4;keepalive=5", func(p connectParams) bool {
|
|
return p.conn_timeout == 3*time.Second && p.dial_timeout == 4*time.Second && p.keepAlive == 5*time.Second
|
|
}},
|
|
{"log=63", func(p connectParams) bool { return p.logFlags == 63 && p.port == 1433 }},
|
|
{"log=63;port=1000", func(p connectParams) bool { return p.logFlags == 63 && p.port == 1000 }},
|
|
{"log=64", func(p connectParams) bool { return p.logFlags == 64 && p.packetSize == 4096 }},
|
|
{"log=64;packet size=0", func(p connectParams) bool { return p.logFlags == 64 && p.packetSize == 512 }},
|
|
{"log=64;packet size=300", func(p connectParams) bool { return p.logFlags == 64 && p.packetSize == 512 }},
|
|
{"log=64;packet size=8192", func(p connectParams) bool { return p.logFlags == 64 && p.packetSize == 8192 }},
|
|
{"log=64;packet size=48000", func(p connectParams) bool { return p.logFlags == 64 && p.packetSize == 32767 }},
|
|
|
|
// those are supported currently, but maybe should not be
|
|
{"someparam", func(p connectParams) bool { return true }},
|
|
{";;=;", func(p connectParams) bool { return true }},
|
|
|
|
// ODBC mode
|
|
{"odbc:server=somehost;user id=someuser;password=somepass", func(p connectParams) bool {
|
|
return p.host == "somehost" && p.user == "someuser" && p.password == "somepass"
|
|
}},
|
|
{"odbc:server=somehost;user id=someuser;password=some{pass", func(p connectParams) bool {
|
|
return p.host == "somehost" && p.user == "someuser" && p.password == "some{pass"
|
|
}},
|
|
{"odbc:server={somehost};user id={someuser};password={somepass}", func(p connectParams) bool {
|
|
return p.host == "somehost" && p.user == "someuser" && p.password == "somepass"
|
|
}},
|
|
{"odbc:server={somehost};user id={someuser};password={some=pass}", func(p connectParams) bool {
|
|
return p.host == "somehost" && p.user == "someuser" && p.password == "some=pass"
|
|
}},
|
|
{"odbc:server={somehost};user id={someuser};password={some;pass}", func(p connectParams) bool {
|
|
return p.host == "somehost" && p.user == "someuser" && p.password == "some;pass"
|
|
}},
|
|
{"odbc:server={somehost};user id={someuser};password={some{pass}", func(p connectParams) bool {
|
|
return p.host == "somehost" && p.user == "someuser" && p.password == "some{pass"
|
|
}},
|
|
{"odbc:server={somehost};user id={someuser};password={some}}pass}", func(p connectParams) bool {
|
|
return p.host == "somehost" && p.user == "someuser" && p.password == "some}pass"
|
|
}},
|
|
{"odbc:server={somehost};user id={someuser};password={some{}}p=a;ss}", func(p connectParams) bool {
|
|
return p.host == "somehost" && p.user == "someuser" && p.password == "some{}p=a;ss"
|
|
}},
|
|
{"odbc: server = somehost; user id = someuser ; password = {some pass } ", func(p connectParams) bool {
|
|
return p.host == "somehost" && p.user == "someuser" && p.password == "some pass "
|
|
}},
|
|
|
|
// URL mode
|
|
{"sqlserver://somehost?connection+timeout=30", func(p connectParams) bool {
|
|
return p.host == "somehost" && p.port == 1433 && p.instance == "" && p.conn_timeout == 30*time.Second
|
|
}},
|
|
{"sqlserver://someuser@somehost?connection+timeout=30", func(p connectParams) bool {
|
|
return p.host == "somehost" && p.port == 1433 && p.instance == "" && p.user == "someuser" && p.password == "" && p.conn_timeout == 30*time.Second
|
|
}},
|
|
{"sqlserver://someuser:@somehost?connection+timeout=30", func(p connectParams) bool {
|
|
return p.host == "somehost" && p.port == 1433 && p.instance == "" && p.user == "someuser" && p.password == "" && p.conn_timeout == 30*time.Second
|
|
}},
|
|
{"sqlserver://someuser:foo%3A%2F%5C%21~%40;bar@somehost?connection+timeout=30", func(p connectParams) bool {
|
|
return p.host == "somehost" && p.port == 1433 && p.instance == "" && p.user == "someuser" && p.password == "foo:/\\!~@;bar" && p.conn_timeout == 30*time.Second
|
|
}},
|
|
{"sqlserver://someuser:foo%3A%2F%5C%21~%40;bar@somehost:1434?connection+timeout=30", func(p connectParams) bool {
|
|
return p.host == "somehost" && p.port == 1434 && p.instance == "" && p.user == "someuser" && p.password == "foo:/\\!~@;bar" && p.conn_timeout == 30*time.Second
|
|
}},
|
|
{"sqlserver://someuser:foo%3A%2F%5C%21~%40;bar@somehost:1434/someinstance?connection+timeout=30", func(p connectParams) bool {
|
|
return p.host == "somehost" && p.port == 1434 && p.instance == "someinstance" && p.user == "someuser" && p.password == "foo:/\\!~@;bar" && p.conn_timeout == 30*time.Second
|
|
}},
|
|
}
|
|
for _, ts := range connStrings {
|
|
p, err := parseConnectParams(ts.connStr)
|
|
if err == nil {
|
|
t.Logf("Connection string was parsed successfully %s", ts.connStr)
|
|
} else {
|
|
t.Errorf("Connection string %s failed to parse with error %s", ts.connStr, err)
|
|
continue
|
|
}
|
|
|
|
if !ts.check(p) {
|
|
t.Errorf("Check failed on conn str %s", ts.connStr)
|
|
}
|
|
}
|
|
}
|