Konfi-Castle-Kasino/vendor/github.com/denisenkom/go-mssqldb/tds_test.go

503 lines
15 KiB
Go

package mssql
import (
"bytes"
"database/sql"
"encoding/hex"
"fmt"
"net/url"
"os"
"testing"
"time"
"golang.org/x/net/context"
)
type MockTransport struct {
bytes.Buffer
}
// Close implements io.Closer. The transport holds no external resources, so
// Close always succeeds and leaves any buffered bytes in place.
func (*MockTransport) Close() error {
	return nil
}
// TestSendLogin serializes a fixed login record into an in-memory transport
// and compares the written bytes against a golden reference packet.
func TestSendLogin(t *testing.T) {
	memBuf := new(MockTransport)
	buf := newTdsBuffer(1024, memBuf)
	login := login{
		TDSVersion:     verTDS73,
		PacketSize:     0x1000,
		ClientProgVer:  0x01060100,
		ClientPID:      100,
		ClientTimeZone: -4 * 60,
		ClientID:       [6]byte{0x12, 0x34, 0x56, 0x78, 0x90, 0xab},
		OptionFlags1:   0xe0,
		OptionFlags3:   8,
		HostName:       "subdev1",
		UserName:       "test",
		Password:       "testpwd",
		AppName:        "appname",
		ServerName:     "servername",
		CtlIntName:     "library",
		Language:       "en",
		Database:       "database",
		ClientLCID:     0x204,
		AtchDBFile:     "filepath",
	}
	if err := sendLogin(buf, login); err != nil {
		// Comparing output after a failed send would only produce a
		// misleading byte diff, so stop and report the real cause.
		t.Fatal("sendLogin should succeed:", err)
	}
	ref := []byte{
		16, 1, 0, 222, 0, 0, 1, 0, 198 + 16, 0, 0, 0, 3, 0, 10, 115, 0, 16, 0, 0, 0, 1,
		6, 1, 100, 0, 0, 0, 0, 0, 0, 0, 224, 0, 0, 8, 16, 255, 255, 255, 4, 2, 0,
		0, 94, 0, 7, 0, 108, 0, 4, 0, 116, 0, 7, 0, 130, 0, 7, 0, 144, 0, 10, 0, 0,
		0, 0, 0, 164, 0, 7, 0, 178, 0, 2, 0, 182, 0, 8, 0, 18, 52, 86, 120, 144, 171,
		198, 0, 0, 0, 198, 0, 8, 0, 214, 0, 0, 0, 0, 0, 0, 0, 115, 0, 117, 0, 98,
		0, 100, 0, 101, 0, 118, 0, 49, 0, 116, 0, 101, 0, 115, 0, 116, 0, 226, 165,
		243, 165, 146, 165, 226, 165, 162, 165, 210, 165, 227, 165, 97, 0, 112,
		0, 112, 0, 110, 0, 97, 0, 109, 0, 101, 0, 115, 0, 101, 0, 114, 0, 118, 0,
		101, 0, 114, 0, 110, 0, 97, 0, 109, 0, 101, 0, 108, 0, 105, 0, 98, 0, 114,
		0, 97, 0, 114, 0, 121, 0, 101, 0, 110, 0, 100, 0, 97, 0, 116, 0, 97, 0, 98,
		0, 97, 0, 115, 0, 101, 0, 102, 0, 105, 0, 108, 0, 101, 0, 112, 0, 97, 0,
		116, 0, 104, 0}
	out := memBuf.Bytes()
	if !bytes.Equal(ref, out) {
		// Print both hex dumps so the differing offsets are easy to spot.
		fmt.Println("Expected:")
		fmt.Print(hex.Dump(ref))
		fmt.Println("Returned:")
		fmt.Print(hex.Dump(out))
		t.Error("input output don't match")
	}
}
// TestSendSqlBatch opens a raw TDS connection to the configured server, sends
// "select 1" as a SQL batch, reads tokens until the done token, and checks
// that the first value of the last row is the integer 1. It is skipped when
// no connection settings are present (see checkConnStr).
func TestSendSqlBatch(t *testing.T) {
	checkConnStr(t)
	p, err := parseConnectParams(makeConnStr(t).String())
	if err != nil {
		t.Fatal("parseConnectParams failed:", err.Error())
	}
	conn, err := connect(optionalLogger{testLogger{t}}, p)
	if err != nil {
		t.Fatal("Open connection failed:", err.Error())
	}
	defer conn.buf.transport.Close()
	headers := []headerStruct{
		{hdrtype: dataStmHdrTransDescr,
			data: transDescrHdr{0, 1}.pack()},
	}
	if err := sendSqlBatch72(conn.buf, "select 1", headers); err != nil {
		t.Fatal("Sending sql batch failed", err.Error())
	}
	ch := make(chan tokenStruct, 5)
	go processResponse(context.Background(), conn, ch)
	var lastRow []interface{}
loop:
	for tok := range ch {
		switch token := tok.(type) {
		case doneStruct:
			break loop
		case []columnStruct:
			conn.columns = token
		case []interface{}:
			lastRow = token
		default:
			// Route diagnostics through the test log rather than stdout.
			t.Log("unknown token", tok)
		}
	}
	if len(lastRow) == 0 {
		t.Fatal("expected row but no row set")
	}
	// Any type other than an integer is a failure. Previously an unexpected
	// type skipped the check entirely and the test passed silently.
	switch value := lastRow[0].(type) {
	case int32:
		if value != 1 {
			t.Error("Invalid value returned, should be 1", value)
		}
	case int64:
		if value != 1 {
			t.Error("Invalid value returned, should be 1", value)
		}
	default:
		t.Errorf("unexpected type %T for returned value %v, expected integer 1", value, value)
	}
}
// checkConnStr skips the calling test unless database settings are
// available. Either SQLSERVER_DSN must be non-empty, or both HOST and
// DATABASE must be non-empty.
func checkConnStr(t *testing.T) {
	haveDSN := os.Getenv("SQLSERVER_DSN") != ""
	haveHostAndDB := os.Getenv("HOST") != "" && os.Getenv("DATABASE") != ""
	if !haveDSN && !haveHostAndDB {
		t.Skip("no database connection string")
	}
}
// makeConnStr returns a URL struct so it may be modified by various
// tests before being used as a DSN.
//
// If SQLSERVER_DSN is set, it is parsed and returned with log=127 forced into
// its query string. A malformed value aborts the test via t.Fatal. Otherwise
// a sqlserver:// URL is assembled from HOST, INSTANCE (as the path), SQLUSER,
// SQLPASSWORD and DATABASE, also with log=127.
func makeConnStr(t *testing.T) *url.URL {
	if dsn := os.Getenv("SQLSERVER_DSN"); dsn != "" {
		u, err := url.Parse(dsn)
		if err != nil {
			t.Fatal("unable to parse SQLSERVER_DSN as URL", err)
		}
		query := u.Query()
		query.Set("log", "127")
		u.RawQuery = query.Encode()
		return u
	}

	// Encode sorts keys, so the order of these Set calls does not matter.
	params := make(url.Values)
	params.Set("database", os.Getenv("DATABASE"))
	params.Set("log", "127")
	creds := url.UserPassword(os.Getenv("SQLUSER"), os.Getenv("SQLPASSWORD"))
	return &url.URL{
		Scheme:   "sqlserver",
		User:     creds,
		Host:     os.Getenv("HOST"),
		Path:     os.Getenv("INSTANCE"),
		RawQuery: params.Encode(),
	}
}
// testLogger wraps a *testing.T so that driver log output (it is passed to
// SetLogger and optionalLogger in this file) lands in the test's own log.
type testLogger struct {
	t *testing.T
}
// Printf forwards a formatted message to the wrapped test's log via Logf.
func (l testLogger) Printf(format string, v ...interface{}) {
	tb := l.t
	tb.Logf(format, v...)
}
// Println forwards its arguments to the wrapped test's log via Log.
func (l testLogger) Println(v ...interface{}) {
	tb := l.t
	tb.Log(v...)
}
// open skips the test when no connection settings exist, routes driver
// logging to t, and returns a *sql.DB handle for the DSN from makeConnStr.
// If sql.Open fails it records a test error and returns nil, so callers
// must check for nil. Per database/sql, Open may not dial the server.
func open(t *testing.T) *sql.DB {
	checkConnStr(t)
	SetLogger(testLogger{t})
	db, err := sql.Open("mssql", makeConnStr(t).String())
	if err == nil {
		return db
	}
	t.Error("Open connection failed:", err.Error())
	return nil
}
// TestConnect checks that sql.Open accepts the "mssql" driver with the DSN
// built from the environment, then closes the handle. Per database/sql,
// Open may only validate its arguments without contacting the server.
func TestConnect(t *testing.T) {
	checkConnStr(t)
	SetLogger(testLogger{t})
	dsn := makeConnStr(t).String()
	db, err := sql.Open("mssql", dsn)
	if err != nil {
		t.Error("Open connection failed:", err.Error())
		return
	}
	defer db.Close()
}
// simpleQuery prepares the statement "select 1 as a" on conn. On failure it
// records a test error and returns nil, so callers must check the result.
func simpleQuery(conn *sql.DB, t *testing.T) (stmt *sql.Stmt) {
	prepared, err := conn.Prepare("select 1 as a")
	if err != nil {
		t.Error("Prepare failed:", err.Error())
		return nil
	}
	return prepared
}
// checkSimpleQuery consumes rows and asserts that the result of
// "select 1 as a" is exactly one row whose single column scans to the int 1.
// It also checks rows.Err(), because Next returning false can mean either
// end-of-data or an iteration error.
func checkSimpleQuery(rows *sql.Rows, t *testing.T) {
	numrows := 0
	for rows.Next() {
		numrows++
		var val int
		if err := rows.Scan(&val); err != nil {
			t.Error("Scan failed:", err.Error())
			// val is not meaningful after a failed Scan; don't also
			// report a spurious value mismatch.
			continue
		}
		if val != 1 {
			t.Error("query should return 1")
		}
	}
	if err := rows.Err(); err != nil {
		t.Error("iterating rows failed:", err.Error())
	}
	if numrows != 1 {
		t.Error("query should return 1 row, returned", numrows)
	}
}
// TestQuery prepares and runs "select 1 as a", then verifies the column
// metadata (exactly one column named "a") and the returned row.
func TestQuery(t *testing.T) {
	conn := open(t)
	if conn == nil {
		return
	}
	defer conn.Close()
	stmt := simpleQuery(conn, t)
	if stmt == nil {
		return
	}
	defer stmt.Close()
	rows, err := stmt.Query()
	if err != nil {
		// rows is nil here; continuing would panic on rows.Close().
		t.Error("Query failed:", err.Error())
		return
	}
	defer rows.Close()
	columns, err := rows.Columns()
	if err != nil {
		t.Error("getting columns failed", err.Error())
	}
	// Either mismatch is a failure. The short-circuit also guarantees
	// columns[0] is only read when exactly one column exists.
	if len(columns) != 1 || columns[0] != "a" {
		t.Error("returned incorrect columns (expected ['a']):", columns)
	}
	checkSimpleQuery(rows, t)
}
// TestMultipleQueriesSequentialy runs the same prepared statement twice,
// reading each result set to completion, to check that a statement can be
// reused sequentially.
func TestMultipleQueriesSequentialy(t *testing.T) {
	conn := open(t)
	if conn == nil {
		// open already reported the failure; Close on a nil *sql.DB panics.
		return
	}
	defer conn.Close()
	stmt, err := conn.Prepare("select 1 as a")
	if err != nil {
		t.Error("Prepare failed:", err.Error())
		return
	}
	defer stmt.Close()
	rows, err := stmt.Query()
	if err != nil {
		t.Error("Query failed:", err.Error())
		return
	}
	defer rows.Close()
	checkSimpleQuery(rows, t)
	rows, err = stmt.Query()
	if err != nil {
		t.Error("Query failed:", err.Error())
		return
	}
	defer rows.Close()
	checkSimpleQuery(rows, t)
}
// TestMultipleQueryClose closes a first result set without reading it, then
// re-runs the same prepared statement and verifies the second result.
func TestMultipleQueryClose(t *testing.T) {
	conn := open(t)
	if conn == nil {
		// open already reported the failure; Close on a nil *sql.DB panics.
		return
	}
	defer conn.Close()
	stmt, err := conn.Prepare("select 1 as a")
	if err != nil {
		t.Error("Prepare failed:", err.Error())
		return
	}
	defer stmt.Close()
	rows, err := stmt.Query()
	if err != nil {
		t.Error("Query failed:", err.Error())
		return
	}
	// Deliberately close before consuming any rows.
	rows.Close()
	rows, err = stmt.Query()
	if err != nil {
		t.Error("Query failed:", err.Error())
		return
	}
	defer rows.Close()
	checkSimpleQuery(rows, t)
}
// TestPing verifies that the configured server is reachable via DB.Ping.
func TestPing(t *testing.T) {
	conn := open(t)
	if conn == nil {
		// open already reported the failure; Close on a nil *sql.DB panics.
		return
	}
	defer conn.Close()
	// Previously the Ping error was discarded, so this test could not fail.
	if err := conn.Ping(); err != nil {
		t.Error("Ping failed:", err.Error())
	}
}
// TestSecureWithInvalidHostName requests an encrypted connection with
// certificate validation enabled (TrustServerCertificate=false) but an
// expected certificate host name of "foo.bar". It expects Ping to fail.
func TestSecureWithInvalidHostName(t *testing.T) {
	checkConnStr(t)
	SetLogger(testLogger{t})
	dsn := makeConnStr(t)
	settings := [][2]string{
		{"encrypt", "true"},
		{"TrustServerCertificate", "false"},
		{"hostNameInCertificate", "foo.bar"},
	}
	query := dsn.Query()
	for _, kv := range settings {
		query.Set(kv[0], kv[1])
	}
	dsn.RawQuery = query.Encode()

	db, err := sql.Open("mssql", dsn.String())
	if err != nil {
		t.Fatal("Open connection failed:", err.Error())
	}
	defer db.Close()
	if pingErr := db.Ping(); pingErr == nil {
		t.Fatal("Connected to fake foo.bar server")
	}
}
// TestSecureConnection opens a connection with encrypt=true and
// TrustServerCertificate=true. It checks that a string literal survives the
// round trip and that the server reports the session as encrypted.
func TestSecureConnection(t *testing.T) {
	checkConnStr(t)
	SetLogger(testLogger{t})
	u := makeConnStr(t)
	params := u.Query()
	params.Set("encrypt", "true")
	params.Set("TrustServerCertificate", "true")
	u.RawQuery = params.Encode()

	db, err := sql.Open("mssql", u.String())
	if err != nil {
		t.Fatal("Open connection failed:", err.Error())
	}
	defer db.Close()

	var msg string
	if scanErr := db.QueryRow("select 'secret'").Scan(&msg); scanErr != nil {
		t.Fatal("cannot scan value", scanErr)
	}
	if msg != "secret" {
		t.Fatal("expected secret, got: ", msg)
	}

	// Ask the server whether it considers the current session encrypted.
	var secure bool
	const encQuery = "select encrypt_option from sys.dm_exec_connections where session_id=@@SPID"
	if scanErr := db.QueryRow(encQuery).Scan(&secure); scanErr != nil {
		t.Fatal("cannot scan value", scanErr)
	}
	if !secure {
		t.Fatal("connection is not encrypted")
	}
}
// TestInvalidConnectionString checks that parseConnectParams rejects
// non-parsable option values (numbers, durations, booleans) and malformed
// brace quoting in ODBC-style connection strings.
func TestInvalidConnectionString(t *testing.T) {
	connStrings := []string{
		"log=invalid",
		"port=invalid",
		"packet size=invalid",
		"connection timeout=invalid",
		"dial timeout=invalid",
		"keepalive=invalid",
		"encrypt=invalid",
		"trustservercertificate=invalid",
		"failoverport=invalid",
		// ODBC mode
		"odbc:password={",
		"odbc:password={somepass",
		"odbc:password={somepass}}",
		"odbc:password={some}pass",
	}
	for i := range connStrings {
		connStr := connStrings[i]
		_, err := parseConnectParams(connStr)
		if err != nil {
			t.Logf("Connection failed for %s as expected with error %v", connStr, err)
			continue
		}
		t.Errorf("Connection expected to fail for connection string %s but it didn't", connStr)
	}
}
// TestValidConnectionString parses connection strings in key=value, ODBC
// ("odbc:") and URL ("sqlserver://") form. For each one it applies a
// predicate to the resulting connectParams.
func TestValidConnectionString(t *testing.T) {
	type connStrCase struct {
		connStr string
		check   func(connectParams) bool
	}
	connStrings := []connStrCase{
		{"server=server\\instance;database=testdb;user id=tester;password=pwd", func(p connectParams) bool {
			return p.host == "server" && p.instance == "instance" && p.user == "tester" && p.password == "pwd"
		}},
		{"server=.", func(p connectParams) bool { return p.host == "localhost" }},
		{"server=(local)", func(p connectParams) bool { return p.host == "localhost" }},
		{"ServerSPN=serverspn;Workstation ID=workstid", func(p connectParams) bool { return p.serverSPN == "serverspn" && p.workstation == "workstid" }},
		{"failoverpartner=fopartner;failoverport=2000", func(p connectParams) bool { return p.failOverPartner == "fopartner" && p.failOverPort == 2000 }},
		{"app name=appname;applicationintent=ReadOnly", func(p connectParams) bool { return p.appname == "appname" && (p.typeFlags&fReadOnlyIntent != 0) }},
		{"encrypt=disable", func(p connectParams) bool { return p.disableEncryption }},
		{"encrypt=true", func(p connectParams) bool { return p.encrypt && !p.disableEncryption }},
		{"encrypt=false", func(p connectParams) bool { return !p.encrypt && !p.disableEncryption }},
		{"trustservercertificate=true", func(p connectParams) bool { return p.trustServerCertificate }},
		{"trustservercertificate=false", func(p connectParams) bool { return !p.trustServerCertificate }},
		{"certificate=abc", func(p connectParams) bool { return p.certificate == "abc" }},
		{"hostnameincertificate=abc", func(p connectParams) bool { return p.hostInCertificate == "abc" }},
		{"connection timeout=3;dial timeout=4;keepalive=5", func(p connectParams) bool {
			return p.conn_timeout == 3*time.Second && p.dial_timeout == 4*time.Second && p.keepAlive == 5*time.Second
		}},
		{"log=63", func(p connectParams) bool { return p.logFlags == 63 && p.port == 1433 }},
		{"log=63;port=1000", func(p connectParams) bool { return p.logFlags == 63 && p.port == 1000 }},
		{"log=64", func(p connectParams) bool { return p.logFlags == 64 && p.packetSize == 4096 }},
		{"log=64;packet size=0", func(p connectParams) bool { return p.logFlags == 64 && p.packetSize == 512 }},
		{"log=64;packet size=300", func(p connectParams) bool { return p.logFlags == 64 && p.packetSize == 512 }},
		{"log=64;packet size=8192", func(p connectParams) bool { return p.logFlags == 64 && p.packetSize == 8192 }},
		{"log=64;packet size=48000", func(p connectParams) bool { return p.logFlags == 64 && p.packetSize == 32767 }},
		// those are supported currently, but maybe should not be
		{"someparam", func(p connectParams) bool { return true }},
		{";;=;", func(p connectParams) bool { return true }},
		// ODBC mode
		{"odbc:server=somehost;user id=someuser;password=somepass", func(p connectParams) bool {
			return p.host == "somehost" && p.user == "someuser" && p.password == "somepass"
		}},
		{"odbc:server=somehost;user id=someuser;password=some{pass", func(p connectParams) bool {
			return p.host == "somehost" && p.user == "someuser" && p.password == "some{pass"
		}},
		{"odbc:server={somehost};user id={someuser};password={somepass}", func(p connectParams) bool {
			return p.host == "somehost" && p.user == "someuser" && p.password == "somepass"
		}},
		{"odbc:server={somehost};user id={someuser};password={some=pass}", func(p connectParams) bool {
			return p.host == "somehost" && p.user == "someuser" && p.password == "some=pass"
		}},
		{"odbc:server={somehost};user id={someuser};password={some;pass}", func(p connectParams) bool {
			return p.host == "somehost" && p.user == "someuser" && p.password == "some;pass"
		}},
		{"odbc:server={somehost};user id={someuser};password={some{pass}", func(p connectParams) bool {
			return p.host == "somehost" && p.user == "someuser" && p.password == "some{pass"
		}},
		{"odbc:server={somehost};user id={someuser};password={some}}pass}", func(p connectParams) bool {
			return p.host == "somehost" && p.user == "someuser" && p.password == "some}pass"
		}},
		{"odbc:server={somehost};user id={someuser};password={some{}}p=a;ss}", func(p connectParams) bool {
			return p.host == "somehost" && p.user == "someuser" && p.password == "some{}p=a;ss"
		}},
		{"odbc: server = somehost; user id = someuser ; password = {some pass } ", func(p connectParams) bool {
			return p.host == "somehost" && p.user == "someuser" && p.password == "some pass "
		}},
		// URL mode
		{"sqlserver://somehost?connection+timeout=30", func(p connectParams) bool {
			return p.host == "somehost" && p.port == 1433 && p.instance == "" && p.conn_timeout == 30*time.Second
		}},
		{"sqlserver://someuser@somehost?connection+timeout=30", func(p connectParams) bool {
			return p.host == "somehost" && p.port == 1433 && p.instance == "" && p.user == "someuser" && p.password == "" && p.conn_timeout == 30*time.Second
		}},
		{"sqlserver://someuser:@somehost?connection+timeout=30", func(p connectParams) bool {
			return p.host == "somehost" && p.port == 1433 && p.instance == "" && p.user == "someuser" && p.password == "" && p.conn_timeout == 30*time.Second
		}},
		{"sqlserver://someuser:foo%3A%2F%5C%21~%40;bar@somehost?connection+timeout=30", func(p connectParams) bool {
			return p.host == "somehost" && p.port == 1433 && p.instance == "" && p.user == "someuser" && p.password == "foo:/\\!~@;bar" && p.conn_timeout == 30*time.Second
		}},
		{"sqlserver://someuser:foo%3A%2F%5C%21~%40;bar@somehost:1434?connection+timeout=30", func(p connectParams) bool {
			return p.host == "somehost" && p.port == 1434 && p.instance == "" && p.user == "someuser" && p.password == "foo:/\\!~@;bar" && p.conn_timeout == 30*time.Second
		}},
		{"sqlserver://someuser:foo%3A%2F%5C%21~%40;bar@somehost:1434/someinstance?connection+timeout=30", func(p connectParams) bool {
			return p.host == "somehost" && p.port == 1434 && p.instance == "someinstance" && p.user == "someuser" && p.password == "foo:/\\!~@;bar" && p.conn_timeout == 30*time.Second
		}},
	}
	for _, tc := range connStrings {
		p, err := parseConnectParams(tc.connStr)
		if err != nil {
			t.Errorf("Connection string %s failed to parse with error %s", tc.connStr, err)
			continue
		}
		t.Logf("Connection string was parsed successfully %s", tc.connStr)
		if ok := tc.check(p); !ok {
			t.Errorf("Check failed on conn str %s", tc.connStr)
		}
	}
}