// Copyright (c) 2012 The gocql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package gocql

import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"math/rand"
"net"
"sync"
"sync/atomic"
"time"
)

// SetHosts is the interface to implement to receive the host information.
type SetHosts interface {
SetHosts(hosts []*HostInfo)
}

// SetPartitioner is the interface to implement to receive the partitioner value.
type SetPartitioner interface {
SetPartitioner(partitioner string)
}
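
// setupTLSConfig builds a tls.Config from sslOpts: it loads the optional CA
// certificate and client key pair, derives InsecureSkipVerify from
// EnableHostVerification, and returns a clone of the config to avoid racing
// on the shared value.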
func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
if sslOpts.Config == nil {
sslOpts.Config = &tls.Config{}
}
// ca cert is optional
if sslOpts.CaPath != "" {
if sslOpts.RootCAs == nil {
sslOpts.RootCAs = x509.NewCertPool()
}
pem, err := ioutil.ReadFile(sslOpts.CaPath)
if err != nil {
return nil, fmt.Errorf("connectionpool: unable to open CA certs: %v", err)
}
if !sslOpts.RootCAs.AppendCertsFromPEM(pem) {
			return nil, errors.New("connectionpool: failed parsing of CA certs")
}
}
if sslOpts.CertPath != "" || sslOpts.KeyPath != "" {
mycert, err := tls.LoadX509KeyPair(sslOpts.CertPath, sslOpts.KeyPath)
if err != nil {
return nil, fmt.Errorf("connectionpool: unable to load X509 key pair: %v", err)
}
sslOpts.Certificates = append(sslOpts.Certificates, mycert)
}
sslOpts.InsecureSkipVerify = !sslOpts.EnableHostVerification
// return clone to avoid race
return sslOpts.Config.Clone(), nil
}
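
// policyConnPool maintains one hostConnPool per known host for a session,
// keyed by the host's connect address.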
type policyConnPool struct {
session *Session
port int
numConns int
keyspace string
mu sync.RWMutex
hostConnPools map[string]*hostConnPool
endpoints []string
}
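
// connConfig derives the per-connection configuration from the cluster
// configuration, setting up TLS when SslOpts is provided.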
func connConfig(cfg *ClusterConfig) (*ConnConfig, error) {
var (
err error
tlsConfig *tls.Config
)
// TODO(zariel): move tls config setup into session init.
if cfg.SslOpts != nil {
tlsConfig, err = setupTLSConfig(cfg.SslOpts)
if err != nil {
return nil, err
}
}
return &ConnConfig{
ProtoVersion: cfg.ProtoVersion,
CQLVersion: cfg.CQLVersion,
Timeout: cfg.Timeout,
ConnectTimeout: cfg.ConnectTimeout,
Compressor: cfg.Compressor,
Authenticator: cfg.Authenticator,
AuthProvider: cfg.AuthProvider,
Keepalive: cfg.SocketKeepalive,
tlsConfig: tlsConfig,
disableCoalesce: tlsConfig != nil, // write coalescing doesn't work with framing on top of TCP like in TLS.
}, nil
}
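
// newPolicyConnPool creates an empty policyConnPool for the session and
// copies the configured endpoints.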
func newPolicyConnPool(session *Session) *policyConnPool {
// create the pool
pool := &policyConnPool{
session: session,
port: session.cfg.Port,
numConns: session.cfg.NumConns,
keyspace: session.cfg.Keyspace,
hostConnPools: map[string]*hostConnPool{},
}
pool.endpoints = make([]string, len(session.cfg.Hosts))
copy(pool.endpoints, session.cfg.Hosts)
return pool
}
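
// SetHosts reconciles the managed pools with the given hosts: pools are
// created concurrently for up hosts that are not yet known, kept for hosts
// that are still present, and closed for hosts that are gone.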
func (p *policyConnPool) SetHosts(hosts []*HostInfo) {
p.mu.Lock()
defer p.mu.Unlock()
toRemove := make(map[string]struct{})
for addr := range p.hostConnPools {
toRemove[addr] = struct{}{}
}
pools := make(chan *hostConnPool)
createCount := 0
for _, host := range hosts {
if !host.IsUp() {
// don't create a connection pool for a down host
continue
}
ip := host.ConnectAddress().String()
if _, exists := p.hostConnPools[ip]; exists {
// still have this host, so don't remove it
delete(toRemove, ip)
continue
}
createCount++
go func(host *HostInfo) {
// create a connection pool for the host
pools <- newHostConnPool(
p.session,
host,
p.port,
p.numConns,
p.keyspace,
)
}(host)
}
// add created pools
for createCount > 0 {
pool := <-pools
createCount--
if pool.Size() > 0 {
			// add the pool only if there are connections available,
			// keyed by the textual connect address so getPool lookups match
			p.hostConnPools[pool.host.ConnectAddress().String()] = pool
}
}
for addr := range toRemove {
pool := p.hostConnPools[addr]
delete(p.hostConnPools, addr)
go pool.Close()
}
}
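
// Size returns the total number of connections across all host pools.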
func (p *policyConnPool) Size() int {
p.mu.RLock()
count := 0
for _, pool := range p.hostConnPools {
count += pool.Size()
}
p.mu.RUnlock()
return count
}
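
// getPool returns the connection pool for the given host, if one exists.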
func (p *policyConnPool) getPool(host *HostInfo) (pool *hostConnPool, ok bool) {
ip := host.ConnectAddress().String()
p.mu.RLock()
pool, ok = p.hostConnPools[ip]
p.mu.RUnlock()
return
}
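
// Close closes and removes every host connection pool.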
func (p *policyConnPool) Close() {
p.mu.Lock()
defer p.mu.Unlock()
// close the pools
for addr, pool := range p.hostConnPools {
delete(p.hostConnPools, addr)
pool.Close()
}
}
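
// addHost ensures a connection pool exists for the host and fills it.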
func (p *policyConnPool) addHost(host *HostInfo) {
ip := host.ConnectAddress().String()
p.mu.Lock()
pool, ok := p.hostConnPools[ip]
if !ok {
pool = newHostConnPool(
p.session,
host,
host.Port(), // TODO: if port == 0 use pool.port?
p.numConns,
p.keyspace,
)
p.hostConnPools[ip] = pool
}
p.mu.Unlock()
pool.fill()
}
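
// removeHost drops the pool for the given IP and closes it in the background.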
func (p *policyConnPool) removeHost(ip net.IP) {
k := ip.String()
p.mu.Lock()
pool, ok := p.hostConnPools[k]
if !ok {
p.mu.Unlock()
return
}
delete(p.hostConnPools, k)
p.mu.Unlock()
go pool.Close()
}
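
// hostUp is called when a host is marked up; it ensures the host has a
// connection pool.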
func (p *policyConnPool) hostUp(host *HostInfo) {
	// TODO(zariel): have a set of up hosts and down hosts; we can internally
	// detect down hosts, then try to reconnect to them.
p.addHost(host)
}
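
// hostDown is called when a host is marked down; it removes the host's
// connection pool.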
func (p *policyConnPool) hostDown(ip net.IP) {
// TODO(zariel): mark host as down so we can try to connect to it later, for
	// now just treat it as removed.
p.removeHost(ip)
}

// hostConnPool is a connection pool for a single host.
// Connection selection is based on a provided ConnSelectionPolicy.
type hostConnPool struct {
session *Session
host *HostInfo
port int
addr string
size int
keyspace string
// protection for conns, closed, filling
mu sync.RWMutex
conns []*Conn
closed bool
filling bool
pos uint32
}
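
// String describes the pool state for debugging.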
func (h *hostConnPool) String() string {
h.mu.RLock()
defer h.mu.RUnlock()
return fmt.Sprintf("[filling=%v closed=%v conns=%v size=%v host=%v]",
h.filling, h.closed, len(h.conns), h.size, h.host)
}
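
// newHostConnPool creates an empty connection pool for a single host; the
// pool is not filled or connected until fill is called.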
func newHostConnPool(session *Session, host *HostInfo, port, size int,
keyspace string) *hostConnPool {
pool := &hostConnPool{
session: session,
host: host,
port: port,
addr: (&net.TCPAddr{IP: host.ConnectAddress(), Port: host.Port()}).String(),
size: size,
keyspace: keyspace,
conns: make([]*Conn, 0, size),
filling: false,
closed: false,
}
// the pool is not filled or connected
return pool
}

// Pick returns a connection from the pool, preferring the connection with the
// most available streams, and may trigger an asynchronous fill if the pool is
// below its target size.
func (pool *hostConnPool) Pick() *Conn {
pool.mu.RLock()
defer pool.mu.RUnlock()
if pool.closed {
return nil
}
size := len(pool.conns)
if size < pool.size {
// try to fill the pool
go pool.fill()
if size == 0 {
return nil
}
}
pos := int(atomic.AddUint32(&pool.pos, 1) - 1)
var (
leastBusyConn *Conn
streamsAvailable int
)
	// find the conn which has the most available streams; this is racy
for i := 0; i < size; i++ {
conn := pool.conns[(pos+i)%size]
if streams := conn.AvailableStreams(); streams > streamsAvailable {
leastBusyConn = conn
streamsAvailable = streams
}
}
return leastBusyConn
}

// Size returns the number of connections currently active in the pool.
func (pool *hostConnPool) Size() int {
pool.mu.RLock()
defer pool.mu.RUnlock()
return len(pool.conns)
}

// Close closes the connection pool.
func (pool *hostConnPool) Close() {
pool.mu.Lock()
if pool.closed {
pool.mu.Unlock()
return
}
pool.closed = true
	// ensure we don't try to reacquire the lock in HandleError
// TODO: improve this as the following can happen
// 1) we have locked pool.mu write lock
// 2) conn.Close calls conn.closeWithError(nil)
// 3) conn.closeWithError calls conn.Close() which returns an error
// 4) conn.closeWithError calls pool.HandleError with the error from conn.Close
// 5) pool.HandleError tries to lock pool.mu
// deadlock
// empty the pool
conns := pool.conns
pool.conns = nil
pool.mu.Unlock()
// close the connections
for _, conn := range conns {
conn.Close()
}
}

// Fill the connection pool.
func (pool *hostConnPool) fill() {
pool.mu.RLock()
// avoid filling a closed pool, or concurrent filling
if pool.closed || pool.filling {
pool.mu.RUnlock()
return
}
// determine the filling work to be done
startCount := len(pool.conns)
fillCount := pool.size - startCount
// avoid filling a full (or overfull) pool
if fillCount <= 0 {
pool.mu.RUnlock()
return
}
// switch from read to write lock
pool.mu.RUnlock()
pool.mu.Lock()
// double check everything since the lock was released
startCount = len(pool.conns)
fillCount = pool.size - startCount
if pool.closed || pool.filling || fillCount <= 0 {
// looks like another goroutine already beat this
// goroutine to the filling
pool.mu.Unlock()
return
}
// ok fill the pool
pool.filling = true
// allow others to access the pool while filling
pool.mu.Unlock()
	// only this goroutine should make calls to fill/empty the pool at this
	// point until after this routine or its subordinates call
	// fillingStopped
// fill only the first connection synchronously
if startCount == 0 {
err := pool.connect()
pool.logConnectErr(err)
if err != nil {
// probably unreachable host
pool.fillingStopped(true)
			// this is called with the connection pool mutex held; this call will
			// then recursively try to lock it again. FIXME
if pool.session.cfg.ConvictionPolicy.AddFailure(err, pool.host) {
go pool.session.handleNodeDown(pool.host.ConnectAddress(), pool.port)
}
return
}
// filled one
fillCount--
}
// fill the rest of the pool asynchronously
go func() {
err := pool.connectMany(fillCount)
// mark the end of filling
pool.fillingStopped(err != nil)
}()
}
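
// logConnectErr logs a connection error; dial and read failures, which are
// expected during a node outage, are only logged when debug logging is
// enabled.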
func (pool *hostConnPool) logConnectErr(err error) {
if opErr, ok := err.(*net.OpError); ok && (opErr.Op == "dial" || opErr.Op == "read") {
// connection refused
// these are typical during a node outage so avoid log spam.
if gocqlDebug {
Logger.Printf("unable to dial %q: %v\n", pool.host.ConnectAddress(), err)
}
} else if err != nil {
// unexpected error
Logger.Printf("error: failed to connect to %s due to error: %v", pool.addr, err)
}
}

// fillingStopped transitions the pool back to a not-filling state.
func (pool *hostConnPool) fillingStopped(hadError bool) {
if hadError {
// wait for some time to avoid back-to-back filling
// this provides some time between failed attempts
// to fill the pool for the host to recover
time.Sleep(time.Duration(rand.Int31n(100)+31) * time.Millisecond)
}
pool.mu.Lock()
pool.filling = false
pool.mu.Unlock()
}

// connectMany creates new connections concurrently.
func (pool *hostConnPool) connectMany(count int) error {
if count == 0 {
return nil
}
var (
wg sync.WaitGroup
mu sync.Mutex
connectErr error
)
wg.Add(count)
for i := 0; i < count; i++ {
go func() {
defer wg.Done()
err := pool.connect()
pool.logConnectErr(err)
if err != nil {
mu.Lock()
connectErr = err
mu.Unlock()
}
}()
}
	// wait until all connection attempts are done
wg.Wait()
return connectErr
}

// connect creates a new connection to the host and adds it to the pool,
// retrying according to the session's ReconnectionPolicy.
func (pool *hostConnPool) connect() (err error) {
// TODO: provide a more robust connection retry mechanism, we should also
// be able to detect hosts that come up by trying to connect to downed ones.
// try to connect
var conn *Conn
reconnectionPolicy := pool.session.cfg.ReconnectionPolicy
for i := 0; i < reconnectionPolicy.GetMaxRetries(); i++ {
conn, err = pool.session.connect(pool.session.ctx, pool.host, pool)
if err == nil {
break
}
if opErr, isOpErr := err.(*net.OpError); isOpErr {
			// if the error is not a temporary error (e.g. network unreachable) don't
			// retry
if !opErr.Temporary() {
break
}
}
if gocqlDebug {
Logger.Printf("connection failed %q: %v, reconnecting with %T\n",
pool.host.ConnectAddress(), err, reconnectionPolicy)
}
time.Sleep(reconnectionPolicy.GetInterval(i))
}
if err != nil {
return err
}
if pool.keyspace != "" {
// set the keyspace
if err = conn.UseKeyspace(pool.keyspace); err != nil {
conn.Close()
return err
}
}
// add the Conn to the pool
pool.mu.Lock()
defer pool.mu.Unlock()
if pool.closed {
conn.Close()
return nil
}
pool.conns = append(pool.conns, conn)
return nil
}

// HandleError is called for any error on a Conn; if the connection has been
// closed, it is removed from the pool and an asynchronous refill is started.
func (pool *hostConnPool) HandleError(conn *Conn, err error, closed bool) {
if !closed {
// still an open connection, so continue using it
return
}
// TODO: track the number of errors per host and detect when a host is dead,
// then also have something which can detect when a host comes back.
pool.mu.Lock()
defer pool.mu.Unlock()
if pool.closed {
// pool closed
return
}
// find the connection index
for i, candidate := range pool.conns {
if candidate == conn {
// remove the connection, not preserving order
pool.conns[i], pool.conns = pool.conns[len(pool.conns)-1], pool.conns[:len(pool.conns)-1]
// lost a connection, so fill the pool
go pool.fill()
break
}
}
}