```go
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
// "bytes"
"database/sql"
"errors"
"flag"
"fmt"
// "io"
// "net"
"runtime"
"strconv"
"strings"
// "sync"
// "sync/atomic"
// "testing"
// "time"
// "github.com/alexbrainman/odbc/api"
)
// Command-line flags describing the SQL Server instance and how to reach it.
// They are read by newConnParams when building connection parameters.
var (
// Server name; may also be written as "host,port", in which case
// newConnParams splits it into separate server and port entries.
mssrv = flag.String("mssrv", ".", "ms sql server name")
msdb = flag.String("msdb", "test", "ms sql server database name")
// For non-FreeTDS drivers an empty user name selects trusted_connection=yes.
msuser = flag.String("msuser", "test", "ms sql server user name")
mspass = flag.String("mspass", "test", "ms sql server password")
// Defaults to "sql server" on Windows and "freetds" elsewhere (see defaultDriver).
msdriver = flag.String("msdriver", defaultDriver(), "ms sql odbc driver name")
// Only placed into the connection parameters when the FreeTDS driver is used.
msport = flag.String("msport", "1433", "ms sql server port number")
)
// defaultDriver returns the ODBC driver name used when -msdriver is not
// given: "sql server" on Windows and "freetds" on every other OS.
func defaultDriver() string {
	if runtime.GOOS == "windows" {
		return "sql server"
	}
	return "freetds"
}
// isFreeTDS reports whether the -msdriver flag names the FreeTDS driver.
// The comparison is exact and case-sensitive.
func isFreeTDS() bool {
	driver := *msdriver
	return driver == "freetds"
}
// connParams holds ODBC connection-string attributes keyed by attribute
// name, such as "driver", "server", "database", "uid", "pwd" and "port".
type connParams map[string]string
// newConnParams builds connection parameters from the command-line flags.
//
// The FreeTDS driver always gets explicit uid/pwd and a port. Other drivers
// get uid/pwd when -msuser is non-empty and trusted_connection=yes
// otherwise. If the server value has the form "host,port", it is split into
// separate "server" and "port" entries, overriding any port set above.
func newConnParams() connParams {
	p := connParams{
		"driver":   *msdriver,
		"server":   *mssrv,
		"database": *msdb,
	}
	switch {
	case isFreeTDS():
		p["uid"] = *msuser
		p["pwd"] = *mspass
		p["port"] = *msport
		// Optional FreeTDS settings, currently disabled:
		// clientcharset=UTF-8, debugflags=0xffff.
	case *msuser == "":
		p["trusted_connection"] = "yes"
	default:
		p["uid"] = *msuser
		p["pwd"] = *mspass
	}
	if hostPort := strings.Split(p["server"], ","); len(hostPort) == 2 {
		p["server"], p["port"] = hostPort[0], hostPort[1]
	}
	return p
}
// getConnAddress returns the TCP address "host:port" built from the
// "server" and "port" entries of params. It returns an error when either
// entry is missing; the port is checked first.
//
// NOTE(review): host and port are joined with a bare ':', so an IPv6 literal
// host would not be bracketed; net.JoinHostPort handles that case if needed.
func (params connParams) getConnAddress() (string, error) {
	port, ok := params["port"]
	if !ok {
		// Go error strings are lowercase without trailing punctuation.
		return "", errors.New("no port number provided")
	}
	host, ok := params["server"]
	if !ok {
		return "", errors.New("no host name provided")
	}
	return host + ":" + port, nil
}
// updateConnAddress replaces the "server" and "port" entries of params with
// the host and port parsed from address, which must look like "host:port".
//
// It returns an error, leaving params unchanged, when address does not split
// into exactly two ':'-separated fields (so IPv6 literals are rejected).
func (params connParams) updateConnAddress(address string) error {
	fields := strings.Split(address, ":")
	if len(fields) != 2 {
		// Previously this error was built but never returned, so a short
		// address panicked on fields[1] and a long one was silently truncated.
		return fmt.Errorf("listen address must have 2 fields, but %d found", len(fields))
	}
	params["server"] = fields[0]
	params["port"] = fields[1]
	return nil
}
// makeODBCConnectionString renders params as an ODBC connection string made
// of "name=value;" pairs. A "port" entry is not emitted by itself. Instead
// it is appended to the server value, giving "server=host,port". params
// itself is not modified.
//
// Pairs appear in Go map iteration order, which is unspecified.
//
// NOTE(review): values are emitted verbatim. A value containing ';' (for
// example a password) would corrupt the string. ODBC connection-string
// syntax allows such values to be enclosed in braces; consider that if
// credentials may contain ';'.
func (params connParams) makeODBCConnectionString() string {
	// Work on a copy so the caller's params keep separate "server" and
	// "port" entries, e.g. for a later getConnAddress or a second render.
	attrs := make(map[string]string, len(params))
	for name, value := range params {
		attrs[name] = value
	}
	if port, ok := attrs["port"]; ok {
		attrs["server"] += "," + port
		delete(attrs, "port")
	}
	pairs := make([]string, 0, len(attrs))
	for name, value := range attrs {
		pairs = append(pairs, name+"="+value+";")
	}
	return strings.Join(pairs, "")
}
// mssqlConnectWithParams opens a database handle through the "odbc"
// database/sql driver, using the connection string rendered from params.
//
// stmtCount is always 0. Reading a statement count from the driver's
// statistics is not wired up in this file. On error db is nil.
//
// NOTE(review): the odbc driver import is commented out at the top of this
// file; sql.Open reports an unknown-driver error unless something else
// registers "odbc".
func mssqlConnectWithParams(params connParams) (db *sql.DB, stmtCount int, err error) {
	connStr := params.makeODBCConnectionString()
	if db, err = sql.Open("odbc", connStr); err != nil {
		return nil, 0, err
	}
	return db, 0, nil
}
func mssqlConnect() (db *sql.DB, stmtCount int, err error) {
return mssqlConnectWithParams(newConnParams())
}
// closeDB closes db and prints any error returned by Close to stdout.
//
// shouldStmtCount and ignoreIfStmtCount are accepted but currently unused.
// The statement-count check that consumed them (based on the ODBC driver's
// Stats) has been disabled in this version.
func closeDB(db *sql.DB, shouldStmtCount, ignoreIfStmtCount int) {
	if err := db.Close(); err != nil {
		fmt.Println(err)
	}
}
// connProtoVersion returns the protocol_version of the current session,
// read from master.sys.dm_exec_connections and cast to binary(4).
// It fails if the query fails or the result is not exactly 4 bytes long.
//
// Reference: http://www.mssqltips.com/sqlservertip/2198/determine-which-version-of-sql-server-data-access-driver-is-used-by-an-application/
func connProtoVersion(db *sql.DB) ([]byte, error) {
	const query = "select cast(protocol_version as binary(4)) from master.sys.dm_exec_connections where session_id = @@spid"
	var version []byte
	err := db.QueryRow(query).Scan(&version)
	switch {
	case err != nil:
		return nil, err
	case len(version) != 4:
		return nil, errors.New("failed to fetch connection protocol")
	}
	return version, nil
}
// isProto2008OrLater reports whether the leading byte of the session's
// protocol version (see connProtoVersion) is at least 0x73. This check
// treats 0x73 as the SQL Server 2008 protocol level.
//
// Reference: http://msdn.microsoft.com/en-us/library/dd339982.aspx
func isProto2008OrLater(db *sql.DB) (bool, error) {
	version, err := connProtoVersion(db)
	if err != nil {
		return false, err
	}
	const minLeadingByte = 0x73
	return version[0] >= minLeadingByte, nil
}
// serverVersion runs "select @@version" and splits the text into three
// parts. The text must span at least four lines, for example:
//
//	Microsoft SQL Server 2008 R2 (RTM) - 10.50.1600.1 (X64)
//		Apr  2 2010 15:48:46
//		Copyright (c) Microsoft Corporation
//		Developer Edition (64-bit) on Windows NT 6.1 <X64> (Build 7601: Service Pack 1)
//
// Each line is trimmed of spaces, tabs and '\r'. The results are:
//   - sqlVersion: the first line up to its last '-', followed by the fourth
//     line up to " on ".
//   - sqlPartNumber: the first word after that '-'.
//   - osVersion: the fourth line after " on ".
//
// It returns an error if the query fails or the text does not have that shape.
//
// Reference: http://www.mssqltips.com/sqlservertip/2563/understanding-the-sql-server-select-version-command/
func serverVersion(db *sql.DB) (sqlVersion, sqlPartNumber, osVersion string, err error) {
	var v string
	if err = db.QueryRow("select @@version").Scan(&v); err != nil {
		return "", "", "", err
	}
	a := strings.Split(v, "\n")
	if len(a) < 4 {
		return "", "", "", errors.New("SQL Server version string must have at least 4 lines: " + v)
	}
	for i := range a {
		// '\r' is trimmed too so "\r\n" line endings do not leak into results.
		a[i] = strings.Trim(a[i], " \t\r")
	}
	// Split at the LAST '-': newer builds put dashes inside the release tag,
	// e.g. "(RTM-CU13)" or "(SP1-GDR)", before the " - " that precedes the
	// part number. For a line with a single '-' this matches the old split.
	dash := strings.LastIndex(a[0], "-")
	if dash < 0 {
		return "", "", "", errors.New("SQL Server version first line must have - in it: " + v)
	}
	on := strings.Index(a[3], " on ")
	if on < 0 {
		return "", "", "", errors.New("SQL Server version fourth line must have 'on' in it: " + v)
	}
	sqlVersion = a[0][:dash] + a[3][:on]
	osVersion = a[3][on+len(" on "):]
	sqlPartNumber = strings.Trim(a[0][dash+1:], " ")
	words := strings.Split(sqlPartNumber, " ")
	if len(words) < 2 {
		return "", "", "", errors.New("SQL Server version first line must have space after part number in it: " + v)
	}
	sqlPartNumber = words[0]
	return sqlVersion, sqlPartNumber, osVersion, nil
}
// isSrv2008OrLater reports whether the server part number returned by
// serverVersion has a major component of 10 or more. SQL Server 2008 is
// major version 10. The part number must consist of exactly four
// dot-separated fields whose first field is a decimal integer.
//
// Reference: http://www.mssqltips.com/sqlservertip/2563/understanding-the-sql-server-select-version-command/
func isSrv2008OrLater(db *sql.DB) (bool, error) {
	_, partNumber, _, err := serverVersion(db)
	if err != nil {
		return false, err
	}
	fields := strings.Split(partNumber, ".")
	if len(fields) != 4 {
		return false, errors.New("SQL Server part number must have 4 numbers in it: " + partNumber)
	}
	major, err := strconv.Atoi(fields[0])
	if err != nil {
		return false, errors.New("SQL Server invalid part number: " + partNumber)
	}
	const sql2008Major = 10
	return major >= sql2008Major, nil
}
// is2008OrLater reports whether both the server build and the session's
// protocol version indicate SQL Server 2008 or newer. Any error while
// probing counts as "no". The protocol is only queried after the server
// check has passed.
func is2008OrLater(db *sql.DB) bool {
	if srvOK, err := isSrv2008OrLater(db); err != nil || !srvOK {
		return false
	}
	protoOK, err := isProto2008OrLater(db)
	return err == nil && protoOK
}
// main is intentionally empty.
// NOTE(review): this file appears to be helper code lifted from a test suite
// (note the commented-out "testing" import); main presumably exists only so
// the package builds as a command.
func main() {
}
```
有疑问加站长微信联系(非本文作者)