add vendor, change some js file to typescript

pull/40/head
sunface 5 years ago
parent f9a0b0696d
commit 594f806d59

@ -18,8 +18,8 @@ import (
"fmt"
"os"
"github.com/go-rust/im.dev/internal"
"github.com/spf13/cobra"
"github.com/thinkindev/im.dev/internal"
)
var cfgFile string

@ -0,0 +1,16 @@
module github.com/go-rust/im.dev
go 1.13
require (
github.com/dgrijalva/jwt-go v3.2.0+incompatible // indirect
github.com/gocql/gocql v0.0.0-20200103014340-68f928edb90a
github.com/labstack/echo v3.3.10+incompatible
github.com/labstack/gommon v0.3.0 // indirect
github.com/microcosm-cc/bluemonday v1.0.2
github.com/sony/sonyflake v1.0.0
github.com/spf13/cobra v0.0.5
github.com/valyala/fasttemplate v1.1.0 // indirect
go.uber.org/zap v1.13.0
gopkg.in/yaml.v2 v2.2.7
)

124
go.sum

@ -0,0 +1,124 @@
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8=
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY=
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4=
github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE=
github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk=
github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/deckarep/golang-set v1.7.1 h1:SCQV0S6gTtp6itiFrTqI+pfmJ4LN85S1YzhDf9rTHJQ=
github.com/deckarep/golang-set v1.7.1/go.mod h1:93vsz/8Wt4joVM7c2AVqh+YRMiUSc14yDtF28KmMOgQ=
github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM=
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/gocql/gocql v0.0.0-20200103014340-68f928edb90a h1:f/7VP2EmdQagG92I/75YnM3ZIeCYa61BT2kZoJFptHM=
github.com/gocql/gocql v0.0.0-20200103014340-68f928edb90a/go.mod h1:DL0ekTmBSTdlNF25Orwt/JMzqIq3EJ4MVa/J/uK64OY=
github.com/golang/snappy v0.0.0-20170215233205-553a64147049 h1:K9KHZbXKpGydfDN0aZrsoHpLJlZsBrGMFWbgLDGnPZk=
github.com/golang/snappy v0.0.0-20170215233205-553a64147049/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8=
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/labstack/echo v3.3.10+incompatible h1:pGRcYk231ExFAyoAjAfD85kQzRJCRI8bbnE7CX5OEgg=
github.com/labstack/echo v3.3.10+incompatible/go.mod h1:0INS7j/VjnFxD4E2wkz67b8cVwCLbBmJyDaka6Cmk1s=
github.com/labstack/gommon v0.3.0 h1:JEeO0bvc78PKdyHxloTKiF8BD5iGrH8T6MSeGvSgob0=
github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k=
github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ=
github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx8mU=
github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
github.com/mattn/go-isatty v0.0.9 h1:d5US/mDsogSGW37IV293h//ZFaeajb69h+EHFsv2xGg=
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
github.com/microcosm-cc/bluemonday v1.0.2 h1:5lPfLTTAvAbtS0VqT+94yOtFnGfUWYyx0+iToC3Os3s=
github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc=
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
github.com/sony/sonyflake v1.0.0 h1:MpU6Ro7tfXwgn2l5eluf9xQvQJDROTBImNCfRXn/YeM=
github.com/sony/sonyflake v1.0.0/go.mod h1:Jv3cfhf/UFtolOTTRd3q4Nl6ENqM+KfyZ5PseKfZGF4=
github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ=
github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
github.com/spf13/cobra v0.0.5 h1:f0B+LkLX6DtmRH1isoNA9VTtNUK9K8xYd28JNNfOv/s=
github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU=
github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo=
github.com/spf13/pflag v1.0.3 h1:zPAT6CGy6wXeQ7NtTnaTerfKOsV6V6F8agHXFiazDkg=
github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8=
github.com/valyala/fasttemplate v1.1.0 h1:RZqt0yGBsps8NGvLSGW804QQqCUYYLsaOjTVHy1Ocw4=
github.com/valyala/fasttemplate v1.1.0/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8=
github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q=
go.uber.org/atomic v1.5.0 h1:OI5t8sDa1Or+q8AeE+yKeB/SDYioSHAgcVljj9JIETY=
go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
go.uber.org/multierr v1.3.0 h1:sFPn2GLc3poCkfrpIXGhBD2X0CMIo4Q/zSULXrj/+uc=
go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4=
go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee h1:0mgffUl7nfd+FpvXMVz4IDEaUSmT1ysygQC7qYo7sG4=
go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA=
go.uber.org/zap v1.13.0 h1:nR6NoDBgAf67s68NhaXbsojM+2gxp3S1hWkHDl27pVU=
go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM=
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529 h1:iMGN4xG0cnqj3t+zOM8wUB0BiPKHEwSxEZCvzcbZuvk=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs=
golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc=
golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859 h1:R/3boaszxrf1GEUWTVDzSKVwLmSJpwZ1yqXm8j0v2QI=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a h1:aYOabOQFp6Vj6W1F80affTUvO9UxmJRx8K0gsfABByQ=
golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5 h1:hKsoRgsbwY1NafxrwTs+k64bikrLBkAgPir1TNCj3Zs=
golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo=
gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM=
honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=

@ -1,9 +1,9 @@
package internal
import (
"github.com/go-rust/im.dev/internal/post"
"github.com/go-rust/im.dev/internal/user"
"github.com/labstack/echo"
"github.com/thinkindev/im.dev/internal/post"
"github.com/thinkindev/im.dev/internal/user"
)
func apiHandler(e *echo.Echo) {

@ -5,11 +5,11 @@ import (
"go.uber.org/zap"
"github.com/go-rust/im.dev/internal/misc"
"github.com/go-rust/im.dev/internal/user"
"github.com/gocql/gocql"
"github.com/labstack/echo"
"github.com/labstack/echo/middleware"
"github.com/thinkindev/im.dev/internal/misc"
"github.com/thinkindev/im.dev/internal/user"
)
// Start web server for im.dev ui
@ -32,8 +32,8 @@ func Start(confPath string) {
user.InitUser()
e := echo.New()
e.Pre(middleware.RemoveTrailingSlash())
e.Use(middleware.GzipWithConfig(middleware.GzipConfig{Level: 5}))
// e.Pre(middleware.RemoveTrailingSlash())
// e.Use(middleware.GzipWithConfig(middleware.GzipConfig{Level: 5}))
e.Use(middleware.CORSWithConfig(middleware.CORSConfig{
AllowHeaders: append([]string{echo.HeaderOrigin, echo.HeaderContentType, echo.HeaderAccept, "token"}),
AllowCredentials: true,

@ -6,12 +6,12 @@ import (
"strconv"
"time"
"github.com/go-rust/im.dev/internal/ecode"
"github.com/go-rust/im.dev/internal/misc"
"github.com/go-rust/im.dev/internal/user"
"github.com/go-rust/im.dev/internal/utils"
"github.com/gocql/gocql"
"github.com/labstack/echo"
"github.com/thinkindev/im.dev/internal/ecode"
"github.com/thinkindev/im.dev/internal/misc"
"github.com/thinkindev/im.dev/internal/user"
"github.com/thinkindev/im.dev/internal/utils"
"go.uber.org/zap"
)

@ -9,12 +9,12 @@ import (
"strings"
"time"
"github.com/go-rust/im.dev/internal/ecode"
"github.com/go-rust/im.dev/internal/misc"
"github.com/go-rust/im.dev/internal/user"
"github.com/go-rust/im.dev/internal/utils"
"github.com/gocql/gocql"
"github.com/labstack/echo"
"github.com/thinkindev/im.dev/internal/ecode"
"github.com/thinkindev/im.dev/internal/misc"
"github.com/thinkindev/im.dev/internal/user"
"github.com/thinkindev/im.dev/internal/utils"
"go.uber.org/zap"
)

@ -4,10 +4,10 @@ import (
"fmt"
"net/http"
"github.com/go-rust/im.dev/internal/misc"
"github.com/go-rust/im.dev/internal/user"
"github.com/go-rust/im.dev/internal/utils"
"github.com/labstack/echo"
"github.com/thinkindev/im.dev/internal/misc"
"github.com/thinkindev/im.dev/internal/user"
"github.com/thinkindev/im.dev/internal/utils"
)
// Preview return the new review html of article

@ -6,10 +6,10 @@ import (
"sync"
"time"
"github.com/thinkindev/im.dev/internal/ecode"
"github.com/go-rust/im.dev/internal/ecode"
"github.com/go-rust/im.dev/internal/misc"
"github.com/labstack/echo"
"github.com/thinkindev/im.dev/internal/misc"
)
// Session contains user's info

@ -11,10 +11,10 @@ import (
"github.com/labstack/echo"
"go.uber.org/zap"
"github.com/thinkindev/im.dev/internal/ecode"
"github.com/thinkindev/im.dev/internal/misc"
"github.com/thinkindev/im.dev/internal/utils"
"github.com/thinkindev/im.dev/internal/utils/validate"
"github.com/go-rust/im.dev/internal/ecode"
"github.com/go-rust/im.dev/internal/misc"
"github.com/go-rust/im.dev/internal/utils"
"github.com/go-rust/im.dev/internal/utils/validate"
)
// InitUser insert preserve users

@ -14,7 +14,7 @@
package main
import "github.com/thinkindev/im.dev/cmd"
import "github.com/go-rust/im.dev/cmd"
func main() {
cmd.Execute()

@ -2,7 +2,7 @@ import React from 'react'
import { inject, observer } from 'mobx-react'
import { IntlProvider } from 'react-intl' /* react-intl imports */
import locale from '@library/locale'
import locale from '../../library/locale'
const Intl = inject('system')(observer((props) =>{
let {system} = props

@ -9,7 +9,7 @@ import stores from './store'
ReactDOM.render(
<Router>
<Provider {...stores}>
<App />
<App />
</Provider>
</Router>
, document.getElementById('root'))

@ -3,7 +3,7 @@ import { Layout, Icon, Badge, Avatar, Popover } from 'antd'
import { useHistory } from 'react-router-dom'
import { inject, observer } from 'mobx-react'
import { useMediaQuery } from 'react-responsive'
import { removeToken } from '@utils/auth'
import { removeToken } from '../../library/utils/auth'
import style from './index.module.less'
const { Header } = Layout
@ -67,7 +67,7 @@ const HeaderWrapper = inject('system', 'user')(observer((props) =>{
<div>
<Popover className={`${style.pointer}`} placement="bottomRight" content={userPopover}>
<Badge dot>
<Avatar icon="user" src={user.info.get('avatar')}/>
<Avatar icon="user" src={user.info.avatar}/>
</Badge>
</Popover>
</div>

@ -1,4 +1,4 @@
import Layout from '@/layouts/Layout'
export {
Layout
}
}

@ -21,6 +21,7 @@ let Index = inject('system')(observer((props:any) => {
// })
let history = useHistory()
useEffect(() => {
alert(11)
if(isEmpty(getToken())){
history.push('/login')
}

@ -29,6 +29,7 @@ function FormBox(props) {
delete res.data.token
storage.set('info', res.data)
props.setloading(true)
console.log(res.data)
user.setInfo(res.data)
setTimeout(() => {
props.setloading(false)

@ -31,15 +31,15 @@ class System{
this.drawer = !this.drawer
}
@action
setPrimary = (color) => {
setPrimary = (color:string) => {
this.primary = color
}
@action
setLocale = (locale) => {
setLocale = (locale:string) => {
this.locale = locale
}
@action
setLang = (lang) => {
setLang = (lang:string) => {
this.lang = lang
}
}

@ -1,12 +0,0 @@
import { action, observable } from 'mobx'
class User{
@observable info = new Map()
@action
setInfo = (info) => {
this.info.replace(info)
}
}
export default new User()

@ -0,0 +1,18 @@
import { action, observable } from 'mobx'
type UserInfo = {
address: string,
email: string,
tel: string,
avatar: string
}
class User{
@observable info:UserInfo = {address:'',email:'',tel:'',avatar:''}
@action
setInfo = (info:UserInfo) => {
this.info = info
}
}
export default new User()

@ -0,0 +1,5 @@
TAGS
tags
.*.swp
tomlcheck/tomlcheck
toml.test

@ -0,0 +1,15 @@
language: go
go:
- 1.1
- 1.2
- 1.3
- 1.4
- 1.5
- 1.6
- tip
install:
- go install ./...
- go get github.com/BurntSushi/toml-test
script:
- export PATH="$PATH:$HOME/gopath/bin"
- make test

@ -0,0 +1,3 @@
Compatible with TOML version
[v0.4.0](https://github.com/toml-lang/toml/blob/v0.4.0/versions/en/toml-v0.4.0.md)

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2013 TOML authors
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

@ -0,0 +1,19 @@
install:
go install ./...
test: install
go test -v
toml-test toml-test-decoder
toml-test -encoder toml-test-encoder
fmt:
gofmt -w *.go */*.go
colcheck *.go */*.go
tags:
find ./ -name '*.go' -print0 | xargs -0 gotags > TAGS
push:
git push origin master
git push github master

@ -0,0 +1,218 @@
## TOML parser and encoder for Go with reflection
TOML stands for Tom's Obvious, Minimal Language. This Go package provides a
reflection interface similar to Go's standard library `json` and `xml`
packages. This package also supports the `encoding.TextUnmarshaler` and
`encoding.TextMarshaler` interfaces so that you can define custom data
representations. (There is an example of this below.)
Spec: https://github.com/toml-lang/toml
Compatible with TOML version
[v0.4.0](https://github.com/toml-lang/toml/blob/master/versions/en/toml-v0.4.0.md)
Documentation: https://godoc.org/github.com/BurntSushi/toml
Installation:
```bash
go get github.com/BurntSushi/toml
```
Try the toml validator:
```bash
go get github.com/BurntSushi/toml/cmd/tomlv
tomlv some-toml-file.toml
```
[![Build Status](https://travis-ci.org/BurntSushi/toml.svg?branch=master)](https://travis-ci.org/BurntSushi/toml) [![GoDoc](https://godoc.org/github.com/BurntSushi/toml?status.svg)](https://godoc.org/github.com/BurntSushi/toml)
### Testing
This package passes all tests in
[toml-test](https://github.com/BurntSushi/toml-test) for both the decoder
and the encoder.
### Examples
This package works similarly to how the Go standard library handles `XML`
and `JSON`. Namely, data is loaded into Go values via reflection.
For the simplest example, consider some TOML file as just a list of keys
and values:
```toml
Age = 25
Cats = [ "Cauchy", "Plato" ]
Pi = 3.14
Perfection = [ 6, 28, 496, 8128 ]
DOB = 1987-07-05T05:45:00Z
```
Which could be defined in Go as:
```go
type Config struct {
Age int
Cats []string
Pi float64
Perfection []int
DOB time.Time // requires `import time`
}
```
And then decoded with:
```go
var conf Config
if _, err := toml.Decode(tomlData, &conf); err != nil {
// handle error
}
```
You can also use struct tags if your struct field name doesn't map to a TOML
key value directly:
```toml
some_key_NAME = "wat"
```
```go
type TOML struct {
ObscureKey string `toml:"some_key_NAME"`
}
```
### Using the `encoding.TextUnmarshaler` interface
Here's an example that automatically parses duration strings into
`time.Duration` values:
```toml
[[song]]
name = "Thunder Road"
duration = "4m49s"
[[song]]
name = "Stairway to Heaven"
duration = "8m03s"
```
Which can be decoded with:
```go
type song struct {
Name string
Duration duration
}
type songs struct {
Song []song
}
var favorites songs
if _, err := toml.Decode(blob, &favorites); err != nil {
log.Fatal(err)
}
for _, s := range favorites.Song {
fmt.Printf("%s (%s)\n", s.Name, s.Duration)
}
```
And you'll also need a `duration` type that satisfies the
`encoding.TextUnmarshaler` interface:
```go
type duration struct {
time.Duration
}
func (d *duration) UnmarshalText(text []byte) error {
var err error
d.Duration, err = time.ParseDuration(string(text))
return err
}
```
### More complex usage
Here's an example of how to load the example from the official spec page:
```toml
# This is a TOML document. Boom.
title = "TOML Example"
[owner]
name = "Tom Preston-Werner"
organization = "GitHub"
bio = "GitHub Cofounder & CEO\nLikes tater tots and beer."
dob = 1979-05-27T07:32:00Z # First class dates? Why not?
[database]
server = "192.168.1.1"
ports = [ 8001, 8001, 8002 ]
connection_max = 5000
enabled = true
[servers]
# You can indent as you please. Tabs or spaces. TOML don't care.
[servers.alpha]
ip = "10.0.0.1"
dc = "eqdc10"
[servers.beta]
ip = "10.0.0.2"
dc = "eqdc10"
[clients]
data = [ ["gamma", "delta"], [1, 2] ] # just an update to make sure parsers support it
# Line breaks are OK when inside arrays
hosts = [
"alpha",
"omega"
]
```
And the corresponding Go types are:
```go
type tomlConfig struct {
Title string
Owner ownerInfo
DB database `toml:"database"`
Servers map[string]server
Clients clients
}
type ownerInfo struct {
Name string
Org string `toml:"organization"`
Bio string
DOB time.Time
}
type database struct {
Server string
Ports []int
ConnMax int `toml:"connection_max"`
Enabled bool
}
type server struct {
IP string
DC string
}
type clients struct {
Data [][]interface{}
Hosts []string
}
```
Note that a case insensitive match will be tried if an exact match can't be
found.
A working example of the above can be found in `_examples/example.{go,toml}`.

@ -0,0 +1,509 @@
package toml
import (
"fmt"
"io"
"io/ioutil"
"math"
"reflect"
"strings"
"time"
)
func e(format string, args ...interface{}) error {
return fmt.Errorf("toml: "+format, args...)
}
// Unmarshaler is the interface implemented by objects that can unmarshal a
// TOML description of themselves.
type Unmarshaler interface {
UnmarshalTOML(interface{}) error
}
// Unmarshal decodes the contents of `p` in TOML format into a pointer `v`.
func Unmarshal(p []byte, v interface{}) error {
_, err := Decode(string(p), v)
return err
}
// Primitive is a TOML value that hasn't been decoded into a Go value.
// When using the various `Decode*` functions, the type `Primitive` may
// be given to any value, and its decoding will be delayed.
//
// A `Primitive` value can be decoded using the `PrimitiveDecode` function.
//
// The underlying representation of a `Primitive` value is subject to change.
// Do not rely on it.
//
// N.B. Primitive values are still parsed, so using them will only avoid
// the overhead of reflection. They can be useful when you don't know the
// exact type of TOML data until run time.
type Primitive struct {
undecoded interface{}
context Key
}
// DEPRECATED!
//
// Use MetaData.PrimitiveDecode instead.
func PrimitiveDecode(primValue Primitive, v interface{}) error {
md := MetaData{decoded: make(map[string]bool)}
return md.unify(primValue.undecoded, rvalue(v))
}
// PrimitiveDecode is just like the other `Decode*` functions, except it
// decodes a TOML value that has already been parsed. Valid primitive values
// can *only* be obtained from values filled by the decoder functions,
// including this method. (i.e., `v` may contain more `Primitive`
// values.)
//
// Meta data for primitive values is included in the meta data returned by
// the `Decode*` functions with one exception: keys returned by the Undecoded
// method will only reflect keys that were decoded. Namely, any keys hidden
// behind a Primitive will be considered undecoded. Executing this method will
// update the undecoded keys in the meta data. (See the example.)
func (md *MetaData) PrimitiveDecode(primValue Primitive, v interface{}) error {
md.context = primValue.context
defer func() { md.context = nil }()
return md.unify(primValue.undecoded, rvalue(v))
}
// Decode will decode the contents of `data` in TOML format into a pointer
// `v`.
//
// TOML hashes correspond to Go structs or maps. (Dealer's choice. They can be
// used interchangeably.)
//
// TOML arrays of tables correspond to either a slice of structs or a slice
// of maps.
//
// TOML datetimes correspond to Go `time.Time` values.
//
// All other TOML types (float, string, int, bool and array) correspond
// to the obvious Go types.
//
// An exception to the above rules is if a type implements the
// encoding.TextUnmarshaler interface. In this case, any primitive TOML value
// (floats, strings, integers, booleans and datetimes) will be converted to
// a byte string and given to the value's UnmarshalText method. See the
// Unmarshaler example for a demonstration with time duration strings.
//
// Key mapping
//
// TOML keys can map to either keys in a Go map or field names in a Go
// struct. The special `toml` struct tag may be used to map TOML keys to
// struct fields that don't match the key name exactly. (See the example.)
// A case insensitive match to struct names will be tried if an exact match
// can't be found.
//
// The mapping between TOML values and Go values is loose. That is, there
// may exist TOML values that cannot be placed into your representation, and
// there may be parts of your representation that do not correspond to
// TOML values. This loose mapping can be made stricter by using the IsDefined
// and/or Undecoded methods on the MetaData returned.
//
// This decoder will not handle cyclic types. If a cyclic type is passed,
// `Decode` will not terminate.
func Decode(data string, v interface{}) (MetaData, error) {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr {
return MetaData{}, e("Decode of non-pointer %s", reflect.TypeOf(v))
}
if rv.IsNil() {
return MetaData{}, e("Decode of nil %s", reflect.TypeOf(v))
}
p, err := parse(data)
if err != nil {
return MetaData{}, err
}
md := MetaData{
p.mapping, p.types, p.ordered,
make(map[string]bool, len(p.ordered)), nil,
}
return md, md.unify(p.mapping, indirect(rv))
}
// DecodeFile is just like Decode, except it will automatically read the
// contents of the file at `fpath` and decode it for you.
func DecodeFile(fpath string, v interface{}) (MetaData, error) {
bs, err := ioutil.ReadFile(fpath)
if err != nil {
return MetaData{}, err
}
return Decode(string(bs), v)
}
// DecodeReader is just like Decode, except it will consume all bytes
// from the reader and decode it for you.
func DecodeReader(r io.Reader, v interface{}) (MetaData, error) {
bs, err := ioutil.ReadAll(r)
if err != nil {
return MetaData{}, err
}
return Decode(string(bs), v)
}
// unify performs a sort of type unification based on the structure of `rv`,
// which is the client representation.
//
// Any type mismatch produces an error. Finding a type that we don't know
// how to handle produces an unsupported type error.
func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
// Special case. Look for a `Primitive` value.
if rv.Type() == reflect.TypeOf((*Primitive)(nil)).Elem() {
// Save the undecoded data and the key context into the primitive
// value.
context := make(Key, len(md.context))
copy(context, md.context)
rv.Set(reflect.ValueOf(Primitive{
undecoded: data,
context: context,
}))
return nil
}
// Special case. Unmarshaler Interface support.
if rv.CanAddr() {
if v, ok := rv.Addr().Interface().(Unmarshaler); ok {
return v.UnmarshalTOML(data)
}
}
// Special case. Handle time.Time values specifically.
// TODO: Remove this code when we decide to drop support for Go 1.1.
// This isn't necessary in Go 1.2 because time.Time satisfies the encoding
// interfaces.
if rv.Type().AssignableTo(rvalue(time.Time{}).Type()) {
return md.unifyDatetime(data, rv)
}
// Special case. Look for a value satisfying the TextUnmarshaler interface.
if v, ok := rv.Interface().(TextUnmarshaler); ok {
return md.unifyText(data, v)
}
// BUG(burntsushi)
// The behavior here is incorrect whenever a Go type satisfies the
// encoding.TextUnmarshaler interface but also corresponds to a TOML
// hash or array. In particular, the unmarshaler should only be applied
// to primitive TOML values. But at this point, it will be applied to
// all kinds of values and produce an incorrect error whenever those values
// are hashes or arrays (including arrays of tables).
k := rv.Kind()
// laziness
if k >= reflect.Int && k <= reflect.Uint64 {
return md.unifyInt(data, rv)
}
switch k {
case reflect.Ptr:
elem := reflect.New(rv.Type().Elem())
err := md.unify(data, reflect.Indirect(elem))
if err != nil {
return err
}
rv.Set(elem)
return nil
case reflect.Struct:
return md.unifyStruct(data, rv)
case reflect.Map:
return md.unifyMap(data, rv)
case reflect.Array:
return md.unifyArray(data, rv)
case reflect.Slice:
return md.unifySlice(data, rv)
case reflect.String:
return md.unifyString(data, rv)
case reflect.Bool:
return md.unifyBool(data, rv)
case reflect.Interface:
// we only support empty interfaces.
if rv.NumMethod() > 0 {
return e("unsupported type %s", rv.Type())
}
return md.unifyAnything(data, rv)
case reflect.Float32:
fallthrough
case reflect.Float64:
return md.unifyFloat64(data, rv)
}
return e("unsupported type %s", rv.Kind())
}
func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
tmap, ok := mapping.(map[string]interface{})
if !ok {
if mapping == nil {
return nil
}
return e("type mismatch for %s: expected table but found %T",
rv.Type().String(), mapping)
}
for key, datum := range tmap {
var f *field
fields := cachedTypeFields(rv.Type())
for i := range fields {
ff := &fields[i]
if ff.name == key {
f = ff
break
}
if f == nil && strings.EqualFold(ff.name, key) {
f = ff
}
}
if f != nil {
subv := rv
for _, i := range f.index {
subv = indirect(subv.Field(i))
}
if isUnifiable(subv) {
md.decoded[md.context.add(key).String()] = true
md.context = append(md.context, key)
if err := md.unify(datum, subv); err != nil {
return err
}
md.context = md.context[0 : len(md.context)-1]
} else if f.name != "" {
// Bad user! No soup for you!
return e("cannot write unexported field %s.%s",
rv.Type().String(), f.name)
}
}
}
return nil
}
func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error {
tmap, ok := mapping.(map[string]interface{})
if !ok {
if tmap == nil {
return nil
}
return badtype("map", mapping)
}
if rv.IsNil() {
rv.Set(reflect.MakeMap(rv.Type()))
}
for k, v := range tmap {
md.decoded[md.context.add(k).String()] = true
md.context = append(md.context, k)
rvkey := indirect(reflect.New(rv.Type().Key()))
rvval := reflect.Indirect(reflect.New(rv.Type().Elem()))
if err := md.unify(v, rvval); err != nil {
return err
}
md.context = md.context[0 : len(md.context)-1]
rvkey.SetString(k)
rv.SetMapIndex(rvkey, rvval)
}
return nil
}
func (md *MetaData) unifyArray(data interface{}, rv reflect.Value) error {
datav := reflect.ValueOf(data)
if datav.Kind() != reflect.Slice {
if !datav.IsValid() {
return nil
}
return badtype("slice", data)
}
sliceLen := datav.Len()
if sliceLen != rv.Len() {
return e("expected array length %d; got TOML array of length %d",
rv.Len(), sliceLen)
}
return md.unifySliceArray(datav, rv)
}
func (md *MetaData) unifySlice(data interface{}, rv reflect.Value) error {
datav := reflect.ValueOf(data)
if datav.Kind() != reflect.Slice {
if !datav.IsValid() {
return nil
}
return badtype("slice", data)
}
n := datav.Len()
if rv.IsNil() || rv.Cap() < n {
rv.Set(reflect.MakeSlice(rv.Type(), n, n))
}
rv.SetLen(n)
return md.unifySliceArray(datav, rv)
}
func (md *MetaData) unifySliceArray(data, rv reflect.Value) error {
sliceLen := data.Len()
for i := 0; i < sliceLen; i++ {
v := data.Index(i).Interface()
sliceval := indirect(rv.Index(i))
if err := md.unify(v, sliceval); err != nil {
return err
}
}
return nil
}
func (md *MetaData) unifyDatetime(data interface{}, rv reflect.Value) error {
if _, ok := data.(time.Time); ok {
rv.Set(reflect.ValueOf(data))
return nil
}
return badtype("time.Time", data)
}
func (md *MetaData) unifyString(data interface{}, rv reflect.Value) error {
if s, ok := data.(string); ok {
rv.SetString(s)
return nil
}
return badtype("string", data)
}
func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error {
if num, ok := data.(float64); ok {
switch rv.Kind() {
case reflect.Float32:
fallthrough
case reflect.Float64:
rv.SetFloat(num)
default:
panic("bug")
}
return nil
}
return badtype("float", data)
}
func (md *MetaData) unifyInt(data interface{}, rv reflect.Value) error {
if num, ok := data.(int64); ok {
if rv.Kind() >= reflect.Int && rv.Kind() <= reflect.Int64 {
switch rv.Kind() {
case reflect.Int, reflect.Int64:
// No bounds checking necessary.
case reflect.Int8:
if num < math.MinInt8 || num > math.MaxInt8 {
return e("value %d is out of range for int8", num)
}
case reflect.Int16:
if num < math.MinInt16 || num > math.MaxInt16 {
return e("value %d is out of range for int16", num)
}
case reflect.Int32:
if num < math.MinInt32 || num > math.MaxInt32 {
return e("value %d is out of range for int32", num)
}
}
rv.SetInt(num)
} else if rv.Kind() >= reflect.Uint && rv.Kind() <= reflect.Uint64 {
unum := uint64(num)
switch rv.Kind() {
case reflect.Uint, reflect.Uint64:
// No bounds checking necessary.
case reflect.Uint8:
if num < 0 || unum > math.MaxUint8 {
return e("value %d is out of range for uint8", num)
}
case reflect.Uint16:
if num < 0 || unum > math.MaxUint16 {
return e("value %d is out of range for uint16", num)
}
case reflect.Uint32:
if num < 0 || unum > math.MaxUint32 {
return e("value %d is out of range for uint32", num)
}
}
rv.SetUint(unum)
} else {
panic("unreachable")
}
return nil
}
return badtype("integer", data)
}
func (md *MetaData) unifyBool(data interface{}, rv reflect.Value) error {
if b, ok := data.(bool); ok {
rv.SetBool(b)
return nil
}
return badtype("boolean", data)
}
func (md *MetaData) unifyAnything(data interface{}, rv reflect.Value) error {
rv.Set(reflect.ValueOf(data))
return nil
}
func (md *MetaData) unifyText(data interface{}, v TextUnmarshaler) error {
var s string
switch sdata := data.(type) {
case TextMarshaler:
text, err := sdata.MarshalText()
if err != nil {
return err
}
s = string(text)
case fmt.Stringer:
s = sdata.String()
case string:
s = sdata
case bool:
s = fmt.Sprintf("%v", sdata)
case int64:
s = fmt.Sprintf("%d", sdata)
case float64:
s = fmt.Sprintf("%f", sdata)
default:
return badtype("primitive (string-like)", data)
}
if err := v.UnmarshalText([]byte(s)); err != nil {
return err
}
return nil
}
// rvalue returns a reflect.Value of `v`. All pointers are resolved.
func rvalue(v interface{}) reflect.Value {
return indirect(reflect.ValueOf(v))
}
// indirect returns the value pointed to by a pointer.
// Pointers are followed until the value is not a pointer.
// New values are allocated for each nil pointer.
//
// An exception to this rule is if the value satisfies an interface of
// interest to us (like encoding.TextUnmarshaler).
func indirect(v reflect.Value) reflect.Value {
if v.Kind() != reflect.Ptr {
if v.CanSet() {
pv := v.Addr()
if _, ok := pv.Interface().(TextUnmarshaler); ok {
return pv
}
}
return v
}
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
return indirect(reflect.Indirect(v))
}
func isUnifiable(rv reflect.Value) bool {
if rv.CanSet() {
return true
}
if _, ok := rv.Interface().(TextUnmarshaler); ok {
return true
}
return false
}
func badtype(expected string, data interface{}) error {
return e("cannot load TOML value of type %T into a Go %s", data, expected)
}

@ -0,0 +1,121 @@
package toml
import "strings"
// MetaData allows access to meta information about TOML data that may not
// be inferrable via reflection. In particular, whether a key has been defined
// and the TOML type of a key.
type MetaData struct {
mapping map[string]interface{}
types map[string]tomlType
keys []Key
decoded map[string]bool
context Key // Used only during decoding.
}
// IsDefined returns true if the key given exists in the TOML data. The key
// should be specified hierarchially. e.g.,
//
// // access the TOML key 'a.b.c'
// IsDefined("a", "b", "c")
//
// IsDefined will return false if an empty key given. Keys are case sensitive.
func (md *MetaData) IsDefined(key ...string) bool {
if len(key) == 0 {
return false
}
var hash map[string]interface{}
var ok bool
var hashOrVal interface{} = md.mapping
for _, k := range key {
if hash, ok = hashOrVal.(map[string]interface{}); !ok {
return false
}
if hashOrVal, ok = hash[k]; !ok {
return false
}
}
return true
}
// Type returns a string representation of the type of the key specified.
//
// Type will return the empty string if given an empty key or a key that
// does not exist. Keys are case sensitive.
func (md *MetaData) Type(key ...string) string {
fullkey := strings.Join(key, ".")
if typ, ok := md.types[fullkey]; ok {
return typ.typeString()
}
return ""
}
// Key is the type of any TOML key, including key groups. Use (MetaData).Keys
// to get values of this type.
type Key []string
func (k Key) String() string {
return strings.Join(k, ".")
}
func (k Key) maybeQuotedAll() string {
var ss []string
for i := range k {
ss = append(ss, k.maybeQuoted(i))
}
return strings.Join(ss, ".")
}
func (k Key) maybeQuoted(i int) string {
quote := false
for _, c := range k[i] {
if !isBareKeyChar(c) {
quote = true
break
}
}
if quote {
return "\"" + strings.Replace(k[i], "\"", "\\\"", -1) + "\""
}
return k[i]
}
func (k Key) add(piece string) Key {
newKey := make(Key, len(k)+1)
copy(newKey, k)
newKey[len(k)] = piece
return newKey
}
// Keys returns a slice of every key in the TOML data, including key groups.
// Each key is itself a slice, where the first element is the top of the
// hierarchy and the last is the most specific.
//
// The list will have the same order as the keys appeared in the TOML data.
//
// All keys returned are non-empty.
func (md *MetaData) Keys() []Key {
return md.keys
}
// Undecoded returns all keys that have not been decoded in the order in which
// they appear in the original TOML document.
//
// This includes keys that haven't been decoded because of a Primitive value.
// Once the Primitive value is decoded, the keys will be considered decoded.
//
// Also note that decoding into an empty interface will result in no decoding,
// and so no keys will be considered decoded.
//
// In this sense, the Undecoded keys correspond to keys in the TOML document
// that do not have a concrete type in your representation.
func (md *MetaData) Undecoded() []Key {
undecoded := make([]Key, 0, len(md.keys))
for _, key := range md.keys {
if !md.decoded[key.String()] {
undecoded = append(undecoded, key)
}
}
return undecoded
}

@ -0,0 +1,27 @@
/*
Package toml provides facilities for decoding and encoding TOML configuration
files via reflection. There is also support for delaying decoding with
the Primitive type, and querying the set of keys in a TOML document with the
MetaData type.
The specification implemented: https://github.com/toml-lang/toml
The sub-command github.com/BurntSushi/toml/cmd/tomlv can be used to verify
whether a file is a valid TOML document. It can also be used to print the
type of each key in a TOML document.
Testing
There are two important types of tests used for this package. The first is
contained inside '*_test.go' files and uses the standard Go unit testing
framework. These tests are primarily devoted to holistically testing the
decoder and encoder.
The second type of testing is used to verify the implementation's adherence
to the TOML specification. These tests have been factored into their own
project: https://github.com/BurntSushi/toml-test
The reason the tests are in a separate project is so that they can be used by
any implementation of TOML. Namely, it is language agnostic.
*/
package toml

@ -0,0 +1,568 @@
package toml
import (
"bufio"
"errors"
"fmt"
"io"
"reflect"
"sort"
"strconv"
"strings"
"time"
)
type tomlEncodeError struct{ error }
var (
errArrayMixedElementTypes = errors.New(
"toml: cannot encode array with mixed element types")
errArrayNilElement = errors.New(
"toml: cannot encode array with nil element")
errNonString = errors.New(
"toml: cannot encode a map with non-string key type")
errAnonNonStruct = errors.New(
"toml: cannot encode an anonymous field that is not a struct")
errArrayNoTable = errors.New(
"toml: TOML array element cannot contain a table")
errNoKey = errors.New(
"toml: top-level values must be Go maps or structs")
errAnything = errors.New("") // used in testing
)
var quotedReplacer = strings.NewReplacer(
"\t", "\\t",
"\n", "\\n",
"\r", "\\r",
"\"", "\\\"",
"\\", "\\\\",
)
// Encoder controls the encoding of Go values to a TOML document to some
// io.Writer.
//
// The indentation level can be controlled with the Indent field.
type Encoder struct {
// A single indentation level. By default it is two spaces.
Indent string
// hasWritten is whether we have written any output to w yet.
hasWritten bool
w *bufio.Writer
}
// NewEncoder returns a TOML encoder that encodes Go values to the io.Writer
// given. By default, a single indentation level is 2 spaces.
func NewEncoder(w io.Writer) *Encoder {
return &Encoder{
w: bufio.NewWriter(w),
Indent: " ",
}
}
// Encode writes a TOML representation of the Go value to the underlying
// io.Writer. If the value given cannot be encoded to a valid TOML document,
// then an error is returned.
//
// The mapping between Go values and TOML values should be precisely the same
// as for the Decode* functions. Similarly, the TextMarshaler interface is
// supported by encoding the resulting bytes as strings. (If you want to write
// arbitrary binary data then you will need to use something like base64 since
// TOML does not have any binary types.)
//
// When encoding TOML hashes (i.e., Go maps or structs), keys without any
// sub-hashes are encoded first.
//
// If a Go map is encoded, then its keys are sorted alphabetically for
// deterministic output. More control over this behavior may be provided if
// there is demand for it.
//
// Encoding Go values without a corresponding TOML representation---like map
// types with non-string keys---will cause an error to be returned. Similarly
// for mixed arrays/slices, arrays/slices with nil elements, embedded
// non-struct types and nested slices containing maps or structs.
// (e.g., [][]map[string]string is not allowed but []map[string]string is OK
// and so is []map[string][]string.)
func (enc *Encoder) Encode(v interface{}) error {
rv := eindirect(reflect.ValueOf(v))
if err := enc.safeEncode(Key([]string{}), rv); err != nil {
return err
}
return enc.w.Flush()
}
func (enc *Encoder) safeEncode(key Key, rv reflect.Value) (err error) {
defer func() {
if r := recover(); r != nil {
if terr, ok := r.(tomlEncodeError); ok {
err = terr.error
return
}
panic(r)
}
}()
enc.encode(key, rv)
return nil
}
func (enc *Encoder) encode(key Key, rv reflect.Value) {
// Special case. Time needs to be in ISO8601 format.
// Special case. If we can marshal the type to text, then we used that.
// Basically, this prevents the encoder for handling these types as
// generic structs (or whatever the underlying type of a TextMarshaler is).
switch rv.Interface().(type) {
case time.Time, TextMarshaler:
enc.keyEqElement(key, rv)
return
}
k := rv.Kind()
switch k {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
reflect.Uint64,
reflect.Float32, reflect.Float64, reflect.String, reflect.Bool:
enc.keyEqElement(key, rv)
case reflect.Array, reflect.Slice:
if typeEqual(tomlArrayHash, tomlTypeOfGo(rv)) {
enc.eArrayOfTables(key, rv)
} else {
enc.keyEqElement(key, rv)
}
case reflect.Interface:
if rv.IsNil() {
return
}
enc.encode(key, rv.Elem())
case reflect.Map:
if rv.IsNil() {
return
}
enc.eTable(key, rv)
case reflect.Ptr:
if rv.IsNil() {
return
}
enc.encode(key, rv.Elem())
case reflect.Struct:
enc.eTable(key, rv)
default:
panic(e("unsupported type for key '%s': %s", key, k))
}
}
// eElement encodes any value that can be an array element (primitives and
// arrays).
func (enc *Encoder) eElement(rv reflect.Value) {
switch v := rv.Interface().(type) {
case time.Time:
// Special case time.Time as a primitive. Has to come before
// TextMarshaler below because time.Time implements
// encoding.TextMarshaler, but we need to always use UTC.
enc.wf(v.UTC().Format("2006-01-02T15:04:05Z"))
return
case TextMarshaler:
// Special case. Use text marshaler if it's available for this value.
if s, err := v.MarshalText(); err != nil {
encPanic(err)
} else {
enc.writeQuoted(string(s))
}
return
}
switch rv.Kind() {
case reflect.Bool:
enc.wf(strconv.FormatBool(rv.Bool()))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Int64:
enc.wf(strconv.FormatInt(rv.Int(), 10))
case reflect.Uint, reflect.Uint8, reflect.Uint16,
reflect.Uint32, reflect.Uint64:
enc.wf(strconv.FormatUint(rv.Uint(), 10))
case reflect.Float32:
enc.wf(floatAddDecimal(strconv.FormatFloat(rv.Float(), 'f', -1, 32)))
case reflect.Float64:
enc.wf(floatAddDecimal(strconv.FormatFloat(rv.Float(), 'f', -1, 64)))
case reflect.Array, reflect.Slice:
enc.eArrayOrSliceElement(rv)
case reflect.Interface:
enc.eElement(rv.Elem())
case reflect.String:
enc.writeQuoted(rv.String())
default:
panic(e("unexpected primitive type: %s", rv.Kind()))
}
}
// By the TOML spec, all floats must have a decimal with at least one
// number on either side.
func floatAddDecimal(fstr string) string {
if !strings.Contains(fstr, ".") {
return fstr + ".0"
}
return fstr
}
func (enc *Encoder) writeQuoted(s string) {
enc.wf("\"%s\"", quotedReplacer.Replace(s))
}
func (enc *Encoder) eArrayOrSliceElement(rv reflect.Value) {
length := rv.Len()
enc.wf("[")
for i := 0; i < length; i++ {
elem := rv.Index(i)
enc.eElement(elem)
if i != length-1 {
enc.wf(", ")
}
}
enc.wf("]")
}
func (enc *Encoder) eArrayOfTables(key Key, rv reflect.Value) {
if len(key) == 0 {
encPanic(errNoKey)
}
for i := 0; i < rv.Len(); i++ {
trv := rv.Index(i)
if isNil(trv) {
continue
}
panicIfInvalidKey(key)
enc.newline()
enc.wf("%s[[%s]]", enc.indentStr(key), key.maybeQuotedAll())
enc.newline()
enc.eMapOrStruct(key, trv)
}
}
func (enc *Encoder) eTable(key Key, rv reflect.Value) {
panicIfInvalidKey(key)
if len(key) == 1 {
// Output an extra newline between top-level tables.
// (The newline isn't written if nothing else has been written though.)
enc.newline()
}
if len(key) > 0 {
enc.wf("%s[%s]", enc.indentStr(key), key.maybeQuotedAll())
enc.newline()
}
enc.eMapOrStruct(key, rv)
}
func (enc *Encoder) eMapOrStruct(key Key, rv reflect.Value) {
switch rv := eindirect(rv); rv.Kind() {
case reflect.Map:
enc.eMap(key, rv)
case reflect.Struct:
enc.eStruct(key, rv)
default:
panic("eTable: unhandled reflect.Value Kind: " + rv.Kind().String())
}
}
func (enc *Encoder) eMap(key Key, rv reflect.Value) {
rt := rv.Type()
if rt.Key().Kind() != reflect.String {
encPanic(errNonString)
}
// Sort keys so that we have deterministic output. And write keys directly
// underneath this key first, before writing sub-structs or sub-maps.
var mapKeysDirect, mapKeysSub []string
for _, mapKey := range rv.MapKeys() {
k := mapKey.String()
if typeIsHash(tomlTypeOfGo(rv.MapIndex(mapKey))) {
mapKeysSub = append(mapKeysSub, k)
} else {
mapKeysDirect = append(mapKeysDirect, k)
}
}
var writeMapKeys = func(mapKeys []string) {
sort.Strings(mapKeys)
for _, mapKey := range mapKeys {
mrv := rv.MapIndex(reflect.ValueOf(mapKey))
if isNil(mrv) {
// Don't write anything for nil fields.
continue
}
enc.encode(key.add(mapKey), mrv)
}
}
writeMapKeys(mapKeysDirect)
writeMapKeys(mapKeysSub)
}
func (enc *Encoder) eStruct(key Key, rv reflect.Value) {
// Write keys for fields directly under this key first, because if we write
// a field that creates a new table, then all keys under it will be in that
// table (not the one we're writing here).
rt := rv.Type()
var fieldsDirect, fieldsSub [][]int
var addFields func(rt reflect.Type, rv reflect.Value, start []int)
addFields = func(rt reflect.Type, rv reflect.Value, start []int) {
for i := 0; i < rt.NumField(); i++ {
f := rt.Field(i)
// skip unexported fields
if f.PkgPath != "" && !f.Anonymous {
continue
}
frv := rv.Field(i)
if f.Anonymous {
t := f.Type
switch t.Kind() {
case reflect.Struct:
// Treat anonymous struct fields with
// tag names as though they are not
// anonymous, like encoding/json does.
if getOptions(f.Tag).name == "" {
addFields(t, frv, f.Index)
continue
}
case reflect.Ptr:
if t.Elem().Kind() == reflect.Struct &&
getOptions(f.Tag).name == "" {
if !frv.IsNil() {
addFields(t.Elem(), frv.Elem(), f.Index)
}
continue
}
// Fall through to the normal field encoding logic below
// for non-struct anonymous fields.
}
}
if typeIsHash(tomlTypeOfGo(frv)) {
fieldsSub = append(fieldsSub, append(start, f.Index...))
} else {
fieldsDirect = append(fieldsDirect, append(start, f.Index...))
}
}
}
addFields(rt, rv, nil)
var writeFields = func(fields [][]int) {
for _, fieldIndex := range fields {
sft := rt.FieldByIndex(fieldIndex)
sf := rv.FieldByIndex(fieldIndex)
if isNil(sf) {
// Don't write anything for nil fields.
continue
}
opts := getOptions(sft.Tag)
if opts.skip {
continue
}
keyName := sft.Name
if opts.name != "" {
keyName = opts.name
}
if opts.omitempty && isEmpty(sf) {
continue
}
if opts.omitzero && isZero(sf) {
continue
}
enc.encode(key.add(keyName), sf)
}
}
writeFields(fieldsDirect)
writeFields(fieldsSub)
}
// tomlTypeName returns the TOML type name of the Go value's type. It is
// used to determine whether the types of array elements are mixed (which is
// forbidden). If the Go value is nil, then it is illegal for it to be an array
// element, and valueIsNil is returned as true.
// Returns the TOML type of a Go value. The type may be `nil`, which means
// no concrete TOML type could be found.
func tomlTypeOfGo(rv reflect.Value) tomlType {
if isNil(rv) || !rv.IsValid() {
return nil
}
switch rv.Kind() {
case reflect.Bool:
return tomlBool
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
reflect.Uint64:
return tomlInteger
case reflect.Float32, reflect.Float64:
return tomlFloat
case reflect.Array, reflect.Slice:
if typeEqual(tomlHash, tomlArrayType(rv)) {
return tomlArrayHash
}
return tomlArray
case reflect.Ptr, reflect.Interface:
return tomlTypeOfGo(rv.Elem())
case reflect.String:
return tomlString
case reflect.Map:
return tomlHash
case reflect.Struct:
switch rv.Interface().(type) {
case time.Time:
return tomlDatetime
case TextMarshaler:
return tomlString
default:
return tomlHash
}
default:
panic("unexpected reflect.Kind: " + rv.Kind().String())
}
}
// tomlArrayType returns the element type of a TOML array. The type returned
// may be nil if it cannot be determined (e.g., a nil slice or a zero length
// slize). This function may also panic if it finds a type that cannot be
// expressed in TOML (such as nil elements, heterogeneous arrays or directly
// nested arrays of tables).
func tomlArrayType(rv reflect.Value) tomlType {
if isNil(rv) || !rv.IsValid() || rv.Len() == 0 {
return nil
}
firstType := tomlTypeOfGo(rv.Index(0))
if firstType == nil {
encPanic(errArrayNilElement)
}
rvlen := rv.Len()
for i := 1; i < rvlen; i++ {
elem := rv.Index(i)
switch elemType := tomlTypeOfGo(elem); {
case elemType == nil:
encPanic(errArrayNilElement)
case !typeEqual(firstType, elemType):
encPanic(errArrayMixedElementTypes)
}
}
// If we have a nested array, then we must make sure that the nested
// array contains ONLY primitives.
// This checks arbitrarily nested arrays.
if typeEqual(firstType, tomlArray) || typeEqual(firstType, tomlArrayHash) {
nest := tomlArrayType(eindirect(rv.Index(0)))
if typeEqual(nest, tomlHash) || typeEqual(nest, tomlArrayHash) {
encPanic(errArrayNoTable)
}
}
return firstType
}
type tagOptions struct {
skip bool // "-"
name string
omitempty bool
omitzero bool
}
func getOptions(tag reflect.StructTag) tagOptions {
t := tag.Get("toml")
if t == "-" {
return tagOptions{skip: true}
}
var opts tagOptions
parts := strings.Split(t, ",")
opts.name = parts[0]
for _, s := range parts[1:] {
switch s {
case "omitempty":
opts.omitempty = true
case "omitzero":
opts.omitzero = true
}
}
return opts
}
func isZero(rv reflect.Value) bool {
switch rv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return rv.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return rv.Uint() == 0
case reflect.Float32, reflect.Float64:
return rv.Float() == 0.0
}
return false
}
func isEmpty(rv reflect.Value) bool {
switch rv.Kind() {
case reflect.Array, reflect.Slice, reflect.Map, reflect.String:
return rv.Len() == 0
case reflect.Bool:
return !rv.Bool()
}
return false
}
func (enc *Encoder) newline() {
if enc.hasWritten {
enc.wf("\n")
}
}
func (enc *Encoder) keyEqElement(key Key, val reflect.Value) {
if len(key) == 0 {
encPanic(errNoKey)
}
panicIfInvalidKey(key)
enc.wf("%s%s = ", enc.indentStr(key), key.maybeQuoted(len(key)-1))
enc.eElement(val)
enc.newline()
}
func (enc *Encoder) wf(format string, v ...interface{}) {
if _, err := fmt.Fprintf(enc.w, format, v...); err != nil {
encPanic(err)
}
enc.hasWritten = true
}
func (enc *Encoder) indentStr(key Key) string {
return strings.Repeat(enc.Indent, len(key)-1)
}
func encPanic(err error) {
panic(tomlEncodeError{err})
}
func eindirect(v reflect.Value) reflect.Value {
switch v.Kind() {
case reflect.Ptr, reflect.Interface:
return eindirect(v.Elem())
default:
return v
}
}
func isNil(rv reflect.Value) bool {
switch rv.Kind() {
case reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
return rv.IsNil()
default:
return false
}
}
func panicIfInvalidKey(key Key) {
for _, k := range key {
if len(k) == 0 {
encPanic(e("Key '%s' is not a valid table name. Key names "+
"cannot be empty.", key.maybeQuotedAll()))
}
}
}
func isValidKeyName(s string) bool {
return len(s) != 0
}

@ -0,0 +1,19 @@
// +build go1.2
package toml
// In order to support Go 1.1, we define our own TextMarshaler and
// TextUnmarshaler types. For Go 1.2+, we just alias them with the
// standard library interfaces.
import (
"encoding"
)
// TextMarshaler is a synonym for encoding.TextMarshaler. It is defined here
// so that Go 1.1 can be supported.
type TextMarshaler encoding.TextMarshaler
// TextUnmarshaler is a synonym for encoding.TextUnmarshaler. It is defined
// here so that Go 1.1 can be supported.
type TextUnmarshaler encoding.TextUnmarshaler

@ -0,0 +1,18 @@
// +build !go1.2
package toml
// These interfaces were introduced in Go 1.2, so we add them manually when
// compiling for Go 1.1.
// TextMarshaler is a synonym for encoding.TextMarshaler. It is defined here
// so that Go 1.1 can be supported.
type TextMarshaler interface {
MarshalText() (text []byte, err error)
}
// TextUnmarshaler is a synonym for encoding.TextUnmarshaler. It is defined
// here so that Go 1.1 can be supported.
type TextUnmarshaler interface {
UnmarshalText(text []byte) error
}

@ -0,0 +1,953 @@
package toml
import (
"fmt"
"strings"
"unicode"
"unicode/utf8"
)
type itemType int
const (
itemError itemType = iota
itemNIL // used in the parser to indicate no type
itemEOF
itemText
itemString
itemRawString
itemMultilineString
itemRawMultilineString
itemBool
itemInteger
itemFloat
itemDatetime
itemArray // the start of an array
itemArrayEnd
itemTableStart
itemTableEnd
itemArrayTableStart
itemArrayTableEnd
itemKeyStart
itemCommentStart
itemInlineTableStart
itemInlineTableEnd
)
const (
eof = 0
comma = ','
tableStart = '['
tableEnd = ']'
arrayTableStart = '['
arrayTableEnd = ']'
tableSep = '.'
keySep = '='
arrayStart = '['
arrayEnd = ']'
commentStart = '#'
stringStart = '"'
stringEnd = '"'
rawStringStart = '\''
rawStringEnd = '\''
inlineTableStart = '{'
inlineTableEnd = '}'
)
type stateFn func(lx *lexer) stateFn
type lexer struct {
input string
start int
pos int
line int
state stateFn
items chan item
// Allow for backing up up to three runes.
// This is necessary because TOML contains 3-rune tokens (""" and ''').
prevWidths [3]int
nprev int // how many of prevWidths are in use
// If we emit an eof, we can still back up, but it is not OK to call
// next again.
atEOF bool
// A stack of state functions used to maintain context.
// The idea is to reuse parts of the state machine in various places.
// For example, values can appear at the top level or within arbitrarily
// nested arrays. The last state on the stack is used after a value has
// been lexed. Similarly for comments.
stack []stateFn
}
type item struct {
typ itemType
val string
line int
}
func (lx *lexer) nextItem() item {
for {
select {
case item := <-lx.items:
return item
default:
lx.state = lx.state(lx)
}
}
}
func lex(input string) *lexer {
lx := &lexer{
input: input,
state: lexTop,
line: 1,
items: make(chan item, 10),
stack: make([]stateFn, 0, 10),
}
return lx
}
func (lx *lexer) push(state stateFn) {
lx.stack = append(lx.stack, state)
}
func (lx *lexer) pop() stateFn {
if len(lx.stack) == 0 {
return lx.errorf("BUG in lexer: no states to pop")
}
last := lx.stack[len(lx.stack)-1]
lx.stack = lx.stack[0 : len(lx.stack)-1]
return last
}
func (lx *lexer) current() string {
return lx.input[lx.start:lx.pos]
}
func (lx *lexer) emit(typ itemType) {
lx.items <- item{typ, lx.current(), lx.line}
lx.start = lx.pos
}
func (lx *lexer) emitTrim(typ itemType) {
lx.items <- item{typ, strings.TrimSpace(lx.current()), lx.line}
lx.start = lx.pos
}
func (lx *lexer) next() (r rune) {
if lx.atEOF {
panic("next called after EOF")
}
if lx.pos >= len(lx.input) {
lx.atEOF = true
return eof
}
if lx.input[lx.pos] == '\n' {
lx.line++
}
lx.prevWidths[2] = lx.prevWidths[1]
lx.prevWidths[1] = lx.prevWidths[0]
if lx.nprev < 3 {
lx.nprev++
}
r, w := utf8.DecodeRuneInString(lx.input[lx.pos:])
lx.prevWidths[0] = w
lx.pos += w
return r
}
// ignore skips over the pending input before this point.
func (lx *lexer) ignore() {
lx.start = lx.pos
}
// backup steps back one rune. Can be called only twice between calls to next.
func (lx *lexer) backup() {
if lx.atEOF {
lx.atEOF = false
return
}
if lx.nprev < 1 {
panic("backed up too far")
}
w := lx.prevWidths[0]
lx.prevWidths[0] = lx.prevWidths[1]
lx.prevWidths[1] = lx.prevWidths[2]
lx.nprev--
lx.pos -= w
if lx.pos < len(lx.input) && lx.input[lx.pos] == '\n' {
lx.line--
}
}
// accept consumes the next rune if it's equal to `valid`.
func (lx *lexer) accept(valid rune) bool {
if lx.next() == valid {
return true
}
lx.backup()
return false
}
// peek returns but does not consume the next rune in the input.
func (lx *lexer) peek() rune {
r := lx.next()
lx.backup()
return r
}
// skip ignores all input that matches the given predicate.
func (lx *lexer) skip(pred func(rune) bool) {
for {
r := lx.next()
if pred(r) {
continue
}
lx.backup()
lx.ignore()
return
}
}
// errorf stops all lexing by emitting an error and returning `nil`.
// Note that any value that is a character is escaped if it's a special
// character (newlines, tabs, etc.).
func (lx *lexer) errorf(format string, values ...interface{}) stateFn {
lx.items <- item{
itemError,
fmt.Sprintf(format, values...),
lx.line,
}
return nil
}
// lexTop consumes elements at the top level of TOML data.
func lexTop(lx *lexer) stateFn {
r := lx.next()
if isWhitespace(r) || isNL(r) {
return lexSkip(lx, lexTop)
}
switch r {
case commentStart:
lx.push(lexTop)
return lexCommentStart
case tableStart:
return lexTableStart
case eof:
if lx.pos > lx.start {
return lx.errorf("unexpected EOF")
}
lx.emit(itemEOF)
return nil
}
// At this point, the only valid item can be a key, so we back up
// and let the key lexer do the rest.
lx.backup()
lx.push(lexTopEnd)
return lexKeyStart
}
// lexTopEnd is entered whenever a top-level item has been consumed. (A value
// or a table.) It must see only whitespace, and will turn back to lexTop
// upon a newline. If it sees EOF, it will quit the lexer successfully.
func lexTopEnd(lx *lexer) stateFn {
r := lx.next()
switch {
case r == commentStart:
// a comment will read to a newline for us.
lx.push(lexTop)
return lexCommentStart
case isWhitespace(r):
return lexTopEnd
case isNL(r):
lx.ignore()
return lexTop
case r == eof:
lx.emit(itemEOF)
return nil
}
return lx.errorf("expected a top-level item to end with a newline, "+
"comment, or EOF, but got %q instead", r)
}
// lexTable lexes the beginning of a table. Namely, it makes sure that
// it starts with a character other than '.' and ']'.
// It assumes that '[' has already been consumed.
// It also handles the case that this is an item in an array of tables.
// e.g., '[[name]]'.
func lexTableStart(lx *lexer) stateFn {
if lx.peek() == arrayTableStart {
lx.next()
lx.emit(itemArrayTableStart)
lx.push(lexArrayTableEnd)
} else {
lx.emit(itemTableStart)
lx.push(lexTableEnd)
}
return lexTableNameStart
}
func lexTableEnd(lx *lexer) stateFn {
lx.emit(itemTableEnd)
return lexTopEnd
}
func lexArrayTableEnd(lx *lexer) stateFn {
if r := lx.next(); r != arrayTableEnd {
return lx.errorf("expected end of table array name delimiter %q, "+
"but got %q instead", arrayTableEnd, r)
}
lx.emit(itemArrayTableEnd)
return lexTopEnd
}
func lexTableNameStart(lx *lexer) stateFn {
lx.skip(isWhitespace)
switch r := lx.peek(); {
case r == tableEnd || r == eof:
return lx.errorf("unexpected end of table name " +
"(table names cannot be empty)")
case r == tableSep:
return lx.errorf("unexpected table separator " +
"(table names cannot be empty)")
case r == stringStart || r == rawStringStart:
lx.ignore()
lx.push(lexTableNameEnd)
return lexValue // reuse string lexing
default:
return lexBareTableName
}
}
// lexBareTableName lexes the name of a table. It assumes that at least one
// valid character for the table has already been read.
func lexBareTableName(lx *lexer) stateFn {
r := lx.next()
if isBareKeyChar(r) {
return lexBareTableName
}
lx.backup()
lx.emit(itemText)
return lexTableNameEnd
}
// lexTableNameEnd reads the end of a piece of a table name, optionally
// consuming whitespace.
func lexTableNameEnd(lx *lexer) stateFn {
lx.skip(isWhitespace)
switch r := lx.next(); {
case isWhitespace(r):
return lexTableNameEnd
case r == tableSep:
lx.ignore()
return lexTableNameStart
case r == tableEnd:
return lx.pop()
default:
return lx.errorf("expected '.' or ']' to end table name, "+
"but got %q instead", r)
}
}
// lexKeyStart consumes a key name up until the first non-whitespace character.
// lexKeyStart will ignore whitespace.
func lexKeyStart(lx *lexer) stateFn {
r := lx.peek()
switch {
case r == keySep:
return lx.errorf("unexpected key separator %q", keySep)
case isWhitespace(r) || isNL(r):
lx.next()
return lexSkip(lx, lexKeyStart)
case r == stringStart || r == rawStringStart:
lx.ignore()
lx.emit(itemKeyStart)
lx.push(lexKeyEnd)
return lexValue // reuse string lexing
default:
lx.ignore()
lx.emit(itemKeyStart)
return lexBareKey
}
}
// lexBareKey consumes the text of a bare key. Assumes that the first character
// (which is not whitespace) has not yet been consumed.
func lexBareKey(lx *lexer) stateFn {
switch r := lx.next(); {
case isBareKeyChar(r):
return lexBareKey
case isWhitespace(r):
lx.backup()
lx.emit(itemText)
return lexKeyEnd
case r == keySep:
lx.backup()
lx.emit(itemText)
return lexKeyEnd
default:
return lx.errorf("bare keys cannot contain %q", r)
}
}
// lexKeyEnd consumes the end of a key and trims whitespace (up to the key
// separator).
func lexKeyEnd(lx *lexer) stateFn {
switch r := lx.next(); {
case r == keySep:
return lexSkip(lx, lexValue)
case isWhitespace(r):
return lexSkip(lx, lexKeyEnd)
default:
return lx.errorf("expected key separator %q, but got %q instead",
keySep, r)
}
}
// lexValue starts the consumption of a value anywhere a value is expected.
// lexValue will ignore whitespace.
// After a value is lexed, the last state on the next is popped and returned.
func lexValue(lx *lexer) stateFn {
// We allow whitespace to precede a value, but NOT newlines.
// In array syntax, the array states are responsible for ignoring newlines.
r := lx.next()
switch {
case isWhitespace(r):
return lexSkip(lx, lexValue)
case isDigit(r):
lx.backup() // avoid an extra state and use the same as above
return lexNumberOrDateStart
}
switch r {
case arrayStart:
lx.ignore()
lx.emit(itemArray)
return lexArrayValue
case inlineTableStart:
lx.ignore()
lx.emit(itemInlineTableStart)
return lexInlineTableValue
case stringStart:
if lx.accept(stringStart) {
if lx.accept(stringStart) {
lx.ignore() // Ignore """
return lexMultilineString
}
lx.backup()
}
lx.ignore() // ignore the '"'
return lexString
case rawStringStart:
if lx.accept(rawStringStart) {
if lx.accept(rawStringStart) {
lx.ignore() // Ignore """
return lexMultilineRawString
}
lx.backup()
}
lx.ignore() // ignore the "'"
return lexRawString
case '+', '-':
return lexNumberStart
case '.': // special error case, be kind to users
return lx.errorf("floats must start with a digit, not '.'")
}
if unicode.IsLetter(r) {
// Be permissive here; lexBool will give a nice error if the
// user wrote something like
// x = foo
// (i.e. not 'true' or 'false' but is something else word-like.)
lx.backup()
return lexBool
}
return lx.errorf("expected value but found %q instead", r)
}
// lexArrayValue consumes one value in an array. It assumes that '[' or ','
// have already been consumed. All whitespace and newlines are ignored.
func lexArrayValue(lx *lexer) stateFn {
r := lx.next()
switch {
case isWhitespace(r) || isNL(r):
return lexSkip(lx, lexArrayValue)
case r == commentStart:
lx.push(lexArrayValue)
return lexCommentStart
case r == comma:
return lx.errorf("unexpected comma")
case r == arrayEnd:
// NOTE(caleb): The spec isn't clear about whether you can have
// a trailing comma or not, so we'll allow it.
return lexArrayEnd
}
lx.backup()
lx.push(lexArrayValueEnd)
return lexValue
}
// lexArrayValueEnd consumes everything between the end of an array value and
// the next value (or the end of the array): it ignores whitespace and newlines
// and expects either a ',' or a ']'.
func lexArrayValueEnd(lx *lexer) stateFn {
r := lx.next()
switch {
case isWhitespace(r) || isNL(r):
return lexSkip(lx, lexArrayValueEnd)
case r == commentStart:
lx.push(lexArrayValueEnd)
return lexCommentStart
case r == comma:
lx.ignore()
return lexArrayValue // move on to the next value
case r == arrayEnd:
return lexArrayEnd
}
return lx.errorf(
"expected a comma or array terminator %q, but got %q instead",
arrayEnd, r,
)
}
// lexArrayEnd finishes the lexing of an array.
// It assumes that a ']' has just been consumed.
func lexArrayEnd(lx *lexer) stateFn {
lx.ignore()
lx.emit(itemArrayEnd)
return lx.pop()
}
// lexInlineTableValue consumes one key/value pair in an inline table.
// It assumes that '{' or ',' have already been consumed. Whitespace is ignored.
func lexInlineTableValue(lx *lexer) stateFn {
r := lx.next()
switch {
case isWhitespace(r):
return lexSkip(lx, lexInlineTableValue)
case isNL(r):
return lx.errorf("newlines not allowed within inline tables")
case r == commentStart:
lx.push(lexInlineTableValue)
return lexCommentStart
case r == comma:
return lx.errorf("unexpected comma")
case r == inlineTableEnd:
return lexInlineTableEnd
}
lx.backup()
lx.push(lexInlineTableValueEnd)
return lexKeyStart
}
// lexInlineTableValueEnd consumes everything between the end of an inline table
// key/value pair and the next pair (or the end of the table):
// it ignores whitespace and expects either a ',' or a '}'.
func lexInlineTableValueEnd(lx *lexer) stateFn {
r := lx.next()
switch {
case isWhitespace(r):
return lexSkip(lx, lexInlineTableValueEnd)
case isNL(r):
return lx.errorf("newlines not allowed within inline tables")
case r == commentStart:
lx.push(lexInlineTableValueEnd)
return lexCommentStart
case r == comma:
lx.ignore()
return lexInlineTableValue
case r == inlineTableEnd:
return lexInlineTableEnd
}
return lx.errorf("expected a comma or an inline table terminator %q, "+
"but got %q instead", inlineTableEnd, r)
}
// lexInlineTableEnd finishes the lexing of an inline table.
// It assumes that a '}' has just been consumed.
func lexInlineTableEnd(lx *lexer) stateFn {
lx.ignore()
lx.emit(itemInlineTableEnd)
return lx.pop()
}
// lexString consumes the inner contents of a string. It assumes that the
// beginning '"' has already been consumed and ignored.
func lexString(lx *lexer) stateFn {
r := lx.next()
switch {
case r == eof:
return lx.errorf("unexpected EOF")
case isNL(r):
return lx.errorf("strings cannot contain newlines")
case r == '\\':
lx.push(lexString)
return lexStringEscape
case r == stringEnd:
lx.backup()
lx.emit(itemString)
lx.next()
lx.ignore()
return lx.pop()
}
return lexString
}
// lexMultilineString consumes the inner contents of a string. It assumes that
// the beginning '"""' has already been consumed and ignored.
func lexMultilineString(lx *lexer) stateFn {
switch lx.next() {
case eof:
return lx.errorf("unexpected EOF")
case '\\':
return lexMultilineStringEscape
case stringEnd:
if lx.accept(stringEnd) {
if lx.accept(stringEnd) {
lx.backup()
lx.backup()
lx.backup()
lx.emit(itemMultilineString)
lx.next()
lx.next()
lx.next()
lx.ignore()
return lx.pop()
}
lx.backup()
}
}
return lexMultilineString
}
// lexRawString consumes a raw string. Nothing can be escaped in such a string.
// It assumes that the beginning "'" has already been consumed and ignored.
func lexRawString(lx *lexer) stateFn {
r := lx.next()
switch {
case r == eof:
return lx.errorf("unexpected EOF")
case isNL(r):
return lx.errorf("strings cannot contain newlines")
case r == rawStringEnd:
lx.backup()
lx.emit(itemRawString)
lx.next()
lx.ignore()
return lx.pop()
}
return lexRawString
}
// lexMultilineRawString consumes a raw string. Nothing can be escaped in such
// a string. It assumes that the beginning "'''" has already been consumed and
// ignored.
func lexMultilineRawString(lx *lexer) stateFn {
switch lx.next() {
case eof:
return lx.errorf("unexpected EOF")
case rawStringEnd:
if lx.accept(rawStringEnd) {
if lx.accept(rawStringEnd) {
lx.backup()
lx.backup()
lx.backup()
lx.emit(itemRawMultilineString)
lx.next()
lx.next()
lx.next()
lx.ignore()
return lx.pop()
}
lx.backup()
}
}
return lexMultilineRawString
}
// lexMultilineStringEscape consumes an escaped character. It assumes that the
// preceding '\\' has already been consumed.
func lexMultilineStringEscape(lx *lexer) stateFn {
// Handle the special case first:
if isNL(lx.next()) {
return lexMultilineString
}
lx.backup()
lx.push(lexMultilineString)
return lexStringEscape(lx)
}
func lexStringEscape(lx *lexer) stateFn {
r := lx.next()
switch r {
case 'b':
fallthrough
case 't':
fallthrough
case 'n':
fallthrough
case 'f':
fallthrough
case 'r':
fallthrough
case '"':
fallthrough
case '\\':
return lx.pop()
case 'u':
return lexShortUnicodeEscape
case 'U':
return lexLongUnicodeEscape
}
return lx.errorf("invalid escape character %q; only the following "+
"escape characters are allowed: "+
`\b, \t, \n, \f, \r, \", \\, \uXXXX, and \UXXXXXXXX`, r)
}
func lexShortUnicodeEscape(lx *lexer) stateFn {
var r rune
for i := 0; i < 4; i++ {
r = lx.next()
if !isHexadecimal(r) {
return lx.errorf(`expected four hexadecimal digits after '\u', `+
"but got %q instead", lx.current())
}
}
return lx.pop()
}
func lexLongUnicodeEscape(lx *lexer) stateFn {
var r rune
for i := 0; i < 8; i++ {
r = lx.next()
if !isHexadecimal(r) {
return lx.errorf(`expected eight hexadecimal digits after '\U', `+
"but got %q instead", lx.current())
}
}
return lx.pop()
}
// lexNumberOrDateStart consumes either an integer, a float, or datetime.
func lexNumberOrDateStart(lx *lexer) stateFn {
r := lx.next()
if isDigit(r) {
return lexNumberOrDate
}
switch r {
case '_':
return lexNumber
case 'e', 'E':
return lexFloat
case '.':
return lx.errorf("floats must start with a digit, not '.'")
}
return lx.errorf("expected a digit but got %q", r)
}
// lexNumberOrDate consumes either an integer, float or datetime.
func lexNumberOrDate(lx *lexer) stateFn {
r := lx.next()
if isDigit(r) {
return lexNumberOrDate
}
switch r {
case '-':
return lexDatetime
case '_':
return lexNumber
case '.', 'e', 'E':
return lexFloat
}
lx.backup()
lx.emit(itemInteger)
return lx.pop()
}
// lexDatetime consumes a Datetime, to a first approximation.
// The parser validates that it matches one of the accepted formats.
func lexDatetime(lx *lexer) stateFn {
r := lx.next()
if isDigit(r) {
return lexDatetime
}
switch r {
case '-', 'T', ':', '.', 'Z', '+':
return lexDatetime
}
lx.backup()
lx.emit(itemDatetime)
return lx.pop()
}
// lexNumberStart consumes either an integer or a float. It assumes that a sign
// has already been read, but that *no* digits have been consumed.
// lexNumberStart will move to the appropriate integer or float states.
func lexNumberStart(lx *lexer) stateFn {
// We MUST see a digit. Even floats have to start with a digit.
r := lx.next()
if !isDigit(r) {
if r == '.' {
return lx.errorf("floats must start with a digit, not '.'")
}
return lx.errorf("expected a digit but got %q", r)
}
return lexNumber
}
// lexNumber consumes an integer or a float after seeing the first digit.
func lexNumber(lx *lexer) stateFn {
r := lx.next()
if isDigit(r) {
return lexNumber
}
switch r {
case '_':
return lexNumber
case '.', 'e', 'E':
return lexFloat
}
lx.backup()
lx.emit(itemInteger)
return lx.pop()
}
// lexFloat consumes the elements of a float. It allows any sequence of
// float-like characters, so floats emitted by the lexer are only a first
// approximation and must be validated by the parser.
func lexFloat(lx *lexer) stateFn {
r := lx.next()
if isDigit(r) {
return lexFloat
}
switch r {
case '_', '.', '-', '+', 'e', 'E':
return lexFloat
}
lx.backup()
lx.emit(itemFloat)
return lx.pop()
}
// lexBool consumes a bool string: 'true' or 'false.
func lexBool(lx *lexer) stateFn {
var rs []rune
for {
r := lx.next()
if !unicode.IsLetter(r) {
lx.backup()
break
}
rs = append(rs, r)
}
s := string(rs)
switch s {
case "true", "false":
lx.emit(itemBool)
return lx.pop()
}
return lx.errorf("expected value but found %q instead", s)
}
// lexCommentStart begins the lexing of a comment. It will emit
// itemCommentStart and consume no characters, passing control to lexComment.
func lexCommentStart(lx *lexer) stateFn {
lx.ignore()
lx.emit(itemCommentStart)
return lexComment
}
// lexComment lexes an entire comment. It assumes that '#' has been consumed.
// It will consume *up to* the first newline character, and pass control
// back to the last state on the stack.
func lexComment(lx *lexer) stateFn {
r := lx.peek()
if isNL(r) || r == eof {
lx.emit(itemText)
return lx.pop()
}
lx.next()
return lexComment
}
// lexSkip ignores all slurped input and moves on to the next state.
func lexSkip(lx *lexer, nextState stateFn) stateFn {
return func(lx *lexer) stateFn {
lx.ignore()
return nextState
}
}
// isWhitespace returns true if `r` is a whitespace character according
// to the spec.
func isWhitespace(r rune) bool {
return r == '\t' || r == ' '
}
func isNL(r rune) bool {
return r == '\n' || r == '\r'
}
func isDigit(r rune) bool {
return r >= '0' && r <= '9'
}
func isHexadecimal(r rune) bool {
return (r >= '0' && r <= '9') ||
(r >= 'a' && r <= 'f') ||
(r >= 'A' && r <= 'F')
}
func isBareKeyChar(r rune) bool {
return (r >= 'A' && r <= 'Z') ||
(r >= 'a' && r <= 'z') ||
(r >= '0' && r <= '9') ||
r == '_' ||
r == '-'
}
func (itype itemType) String() string {
switch itype {
case itemError:
return "Error"
case itemNIL:
return "NIL"
case itemEOF:
return "EOF"
case itemText:
return "Text"
case itemString, itemRawString, itemMultilineString, itemRawMultilineString:
return "String"
case itemBool:
return "Bool"
case itemInteger:
return "Integer"
case itemFloat:
return "Float"
case itemDatetime:
return "DateTime"
case itemTableStart:
return "TableStart"
case itemTableEnd:
return "TableEnd"
case itemKeyStart:
return "KeyStart"
case itemArray:
return "Array"
case itemArrayEnd:
return "ArrayEnd"
case itemCommentStart:
return "CommentStart"
}
panic(fmt.Sprintf("BUG: Unknown type '%d'.", int(itype)))
}
func (item item) String() string {
return fmt.Sprintf("(%s, %s)", item.typ.String(), item.val)
}

@ -0,0 +1,592 @@
package toml
import (
"fmt"
"strconv"
"strings"
"time"
"unicode"
"unicode/utf8"
)
type parser struct {
mapping map[string]interface{}
types map[string]tomlType
lx *lexer
// A list of keys in the order that they appear in the TOML data.
ordered []Key
// the full key for the current hash in scope
context Key
// the base key name for everything except hashes
currentKey string
// rough approximation of line number
approxLine int
// A map of 'key.group.names' to whether they were created implicitly.
implicits map[string]bool
}
type parseError string
func (pe parseError) Error() string {
return string(pe)
}
func parse(data string) (p *parser, err error) {
defer func() {
if r := recover(); r != nil {
var ok bool
if err, ok = r.(parseError); ok {
return
}
panic(r)
}
}()
p = &parser{
mapping: make(map[string]interface{}),
types: make(map[string]tomlType),
lx: lex(data),
ordered: make([]Key, 0),
implicits: make(map[string]bool),
}
for {
item := p.next()
if item.typ == itemEOF {
break
}
p.topLevel(item)
}
return p, nil
}
func (p *parser) panicf(format string, v ...interface{}) {
msg := fmt.Sprintf("Near line %d (last key parsed '%s'): %s",
p.approxLine, p.current(), fmt.Sprintf(format, v...))
panic(parseError(msg))
}
func (p *parser) next() item {
it := p.lx.nextItem()
if it.typ == itemError {
p.panicf("%s", it.val)
}
return it
}
func (p *parser) bug(format string, v ...interface{}) {
panic(fmt.Sprintf("BUG: "+format+"\n\n", v...))
}
func (p *parser) expect(typ itemType) item {
it := p.next()
p.assertEqual(typ, it.typ)
return it
}
func (p *parser) assertEqual(expected, got itemType) {
if expected != got {
p.bug("Expected '%s' but got '%s'.", expected, got)
}
}
func (p *parser) topLevel(item item) {
switch item.typ {
case itemCommentStart:
p.approxLine = item.line
p.expect(itemText)
case itemTableStart:
kg := p.next()
p.approxLine = kg.line
var key Key
for ; kg.typ != itemTableEnd && kg.typ != itemEOF; kg = p.next() {
key = append(key, p.keyString(kg))
}
p.assertEqual(itemTableEnd, kg.typ)
p.establishContext(key, false)
p.setType("", tomlHash)
p.ordered = append(p.ordered, key)
case itemArrayTableStart:
kg := p.next()
p.approxLine = kg.line
var key Key
for ; kg.typ != itemArrayTableEnd && kg.typ != itemEOF; kg = p.next() {
key = append(key, p.keyString(kg))
}
p.assertEqual(itemArrayTableEnd, kg.typ)
p.establishContext(key, true)
p.setType("", tomlArrayHash)
p.ordered = append(p.ordered, key)
case itemKeyStart:
kname := p.next()
p.approxLine = kname.line
p.currentKey = p.keyString(kname)
val, typ := p.value(p.next())
p.setValue(p.currentKey, val)
p.setType(p.currentKey, typ)
p.ordered = append(p.ordered, p.context.add(p.currentKey))
p.currentKey = ""
default:
p.bug("Unexpected type at top level: %s", item.typ)
}
}
// Gets a string for a key (or part of a key in a table name).
func (p *parser) keyString(it item) string {
switch it.typ {
case itemText:
return it.val
case itemString, itemMultilineString,
itemRawString, itemRawMultilineString:
s, _ := p.value(it)
return s.(string)
default:
p.bug("Unexpected key type: %s", it.typ)
panic("unreachable")
}
}
// value translates an expected value from the lexer into a Go value wrapped
// as an empty interface.
func (p *parser) value(it item) (interface{}, tomlType) {
switch it.typ {
case itemString:
return p.replaceEscapes(it.val), p.typeOfPrimitive(it)
case itemMultilineString:
trimmed := stripFirstNewline(stripEscapedWhitespace(it.val))
return p.replaceEscapes(trimmed), p.typeOfPrimitive(it)
case itemRawString:
return it.val, p.typeOfPrimitive(it)
case itemRawMultilineString:
return stripFirstNewline(it.val), p.typeOfPrimitive(it)
case itemBool:
switch it.val {
case "true":
return true, p.typeOfPrimitive(it)
case "false":
return false, p.typeOfPrimitive(it)
}
p.bug("Expected boolean value, but got '%s'.", it.val)
case itemInteger:
if !numUnderscoresOK(it.val) {
p.panicf("Invalid integer %q: underscores must be surrounded by digits",
it.val)
}
val := strings.Replace(it.val, "_", "", -1)
num, err := strconv.ParseInt(val, 10, 64)
if err != nil {
// Distinguish integer values. Normally, it'd be a bug if the lexer
// provides an invalid integer, but it's possible that the number is
// out of range of valid values (which the lexer cannot determine).
// So mark the former as a bug but the latter as a legitimate user
// error.
if e, ok := err.(*strconv.NumError); ok &&
e.Err == strconv.ErrRange {
p.panicf("Integer '%s' is out of the range of 64-bit "+
"signed integers.", it.val)
} else {
p.bug("Expected integer value, but got '%s'.", it.val)
}
}
return num, p.typeOfPrimitive(it)
case itemFloat:
parts := strings.FieldsFunc(it.val, func(r rune) bool {
switch r {
case '.', 'e', 'E':
return true
}
return false
})
for _, part := range parts {
if !numUnderscoresOK(part) {
p.panicf("Invalid float %q: underscores must be "+
"surrounded by digits", it.val)
}
}
if !numPeriodsOK(it.val) {
// As a special case, numbers like '123.' or '1.e2',
// which are valid as far as Go/strconv are concerned,
// must be rejected because TOML says that a fractional
// part consists of '.' followed by 1+ digits.
p.panicf("Invalid float %q: '.' must be followed "+
"by one or more digits", it.val)
}
val := strings.Replace(it.val, "_", "", -1)
num, err := strconv.ParseFloat(val, 64)
if err != nil {
if e, ok := err.(*strconv.NumError); ok &&
e.Err == strconv.ErrRange {
p.panicf("Float '%s' is out of the range of 64-bit "+
"IEEE-754 floating-point numbers.", it.val)
} else {
p.panicf("Invalid float value: %q", it.val)
}
}
return num, p.typeOfPrimitive(it)
case itemDatetime:
var t time.Time
var ok bool
var err error
for _, format := range []string{
"2006-01-02T15:04:05Z07:00",
"2006-01-02T15:04:05",
"2006-01-02",
} {
t, err = time.ParseInLocation(format, it.val, time.Local)
if err == nil {
ok = true
break
}
}
if !ok {
p.panicf("Invalid TOML Datetime: %q.", it.val)
}
return t, p.typeOfPrimitive(it)
case itemArray:
array := make([]interface{}, 0)
types := make([]tomlType, 0)
for it = p.next(); it.typ != itemArrayEnd; it = p.next() {
if it.typ == itemCommentStart {
p.expect(itemText)
continue
}
val, typ := p.value(it)
array = append(array, val)
types = append(types, typ)
}
return array, p.typeOfArray(types)
case itemInlineTableStart:
var (
hash = make(map[string]interface{})
outerContext = p.context
outerKey = p.currentKey
)
p.context = append(p.context, p.currentKey)
p.currentKey = ""
for it := p.next(); it.typ != itemInlineTableEnd; it = p.next() {
if it.typ != itemKeyStart {
p.bug("Expected key start but instead found %q, around line %d",
it.val, p.approxLine)
}
if it.typ == itemCommentStart {
p.expect(itemText)
continue
}
// retrieve key
k := p.next()
p.approxLine = k.line
kname := p.keyString(k)
// retrieve value
p.currentKey = kname
val, typ := p.value(p.next())
// make sure we keep metadata up to date
p.setType(kname, typ)
p.ordered = append(p.ordered, p.context.add(p.currentKey))
hash[kname] = val
}
p.context = outerContext
p.currentKey = outerKey
return hash, tomlHash
}
p.bug("Unexpected value type: %s", it.typ)
panic("unreachable")
}
// numUnderscoresOK checks whether each underscore in s is surrounded by
// characters that are not underscores.
func numUnderscoresOK(s string) bool {
accept := false
for _, r := range s {
if r == '_' {
if !accept {
return false
}
accept = false
continue
}
accept = true
}
return accept
}
// numPeriodsOK checks whether every period in s is followed by a digit.
func numPeriodsOK(s string) bool {
period := false
for _, r := range s {
if period && !isDigit(r) {
return false
}
period = r == '.'
}
return !period
}
// establishContext sets the current context of the parser,
// where the context is either a hash or an array of hashes. Which one is
// set depends on the value of the `array` parameter.
//
// Establishing the context also makes sure that the key isn't a duplicate, and
// will create implicit hashes automatically.
func (p *parser) establishContext(key Key, array bool) {
var ok bool
// Always start at the top level and drill down for our context.
hashContext := p.mapping
keyContext := make(Key, 0)
// We only need implicit hashes for key[0:-1]
for _, k := range key[0 : len(key)-1] {
_, ok = hashContext[k]
keyContext = append(keyContext, k)
// No key? Make an implicit hash and move on.
if !ok {
p.addImplicit(keyContext)
hashContext[k] = make(map[string]interface{})
}
// If the hash context is actually an array of tables, then set
// the hash context to the last element in that array.
//
// Otherwise, it better be a table, since this MUST be a key group (by
// virtue of it not being the last element in a key).
switch t := hashContext[k].(type) {
case []map[string]interface{}:
hashContext = t[len(t)-1]
case map[string]interface{}:
hashContext = t
default:
p.panicf("Key '%s' was already created as a hash.", keyContext)
}
}
p.context = keyContext
if array {
// If this is the first element for this array, then allocate a new
// list of tables for it.
k := key[len(key)-1]
if _, ok := hashContext[k]; !ok {
hashContext[k] = make([]map[string]interface{}, 0, 5)
}
// Add a new table. But make sure the key hasn't already been used
// for something else.
if hash, ok := hashContext[k].([]map[string]interface{}); ok {
hashContext[k] = append(hash, make(map[string]interface{}))
} else {
p.panicf("Key '%s' was already created and cannot be used as "+
"an array.", keyContext)
}
} else {
p.setValue(key[len(key)-1], make(map[string]interface{}))
}
p.context = append(p.context, key[len(key)-1])
}
// setValue sets the given key to the given value in the current context.
// It will make sure that the key hasn't already been defined, account for
// implicit key groups.
func (p *parser) setValue(key string, value interface{}) {
var tmpHash interface{}
var ok bool
hash := p.mapping
keyContext := make(Key, 0)
for _, k := range p.context {
keyContext = append(keyContext, k)
if tmpHash, ok = hash[k]; !ok {
p.bug("Context for key '%s' has not been established.", keyContext)
}
switch t := tmpHash.(type) {
case []map[string]interface{}:
// The context is a table of hashes. Pick the most recent table
// defined as the current hash.
hash = t[len(t)-1]
case map[string]interface{}:
hash = t
default:
p.bug("Expected hash to have type 'map[string]interface{}', but "+
"it has '%T' instead.", tmpHash)
}
}
keyContext = append(keyContext, key)
if _, ok := hash[key]; ok {
// Typically, if the given key has already been set, then we have
// to raise an error since duplicate keys are disallowed. However,
// it's possible that a key was previously defined implicitly. In this
// case, it is allowed to be redefined concretely. (See the
// `tests/valid/implicit-and-explicit-after.toml` test in `toml-test`.)
//
// But we have to make sure to stop marking it as an implicit. (So that
// another redefinition provokes an error.)
//
// Note that since it has already been defined (as a hash), we don't
// want to overwrite it. So our business is done.
if p.isImplicit(keyContext) {
p.removeImplicit(keyContext)
return
}
// Otherwise, we have a concrete key trying to override a previous
// key, which is *always* wrong.
p.panicf("Key '%s' has already been defined.", keyContext)
}
hash[key] = value
}
// setType sets the type of a particular value at a given key.
// It should be called immediately AFTER setValue.
//
// Note that if `key` is empty, then the type given will be applied to the
// current context (which is either a table or an array of tables).
func (p *parser) setType(key string, typ tomlType) {
keyContext := make(Key, 0, len(p.context)+1)
for _, k := range p.context {
keyContext = append(keyContext, k)
}
if len(key) > 0 { // allow type setting for hashes
keyContext = append(keyContext, key)
}
p.types[keyContext.String()] = typ
}
// addImplicit sets the given Key as having been created implicitly.
func (p *parser) addImplicit(key Key) {
p.implicits[key.String()] = true
}
// removeImplicit stops tagging the given key as having been implicitly
// created.
func (p *parser) removeImplicit(key Key) {
p.implicits[key.String()] = false
}
// isImplicit returns true if the key group pointed to by the key was created
// implicitly.
func (p *parser) isImplicit(key Key) bool {
return p.implicits[key.String()]
}
// current returns the full key name of the current context.
func (p *parser) current() string {
if len(p.currentKey) == 0 {
return p.context.String()
}
if len(p.context) == 0 {
return p.currentKey
}
return fmt.Sprintf("%s.%s", p.context, p.currentKey)
}
func stripFirstNewline(s string) string {
if len(s) == 0 || s[0] != '\n' {
return s
}
return s[1:]
}
func stripEscapedWhitespace(s string) string {
esc := strings.Split(s, "\\\n")
if len(esc) > 1 {
for i := 1; i < len(esc); i++ {
esc[i] = strings.TrimLeftFunc(esc[i], unicode.IsSpace)
}
}
return strings.Join(esc, "")
}
func (p *parser) replaceEscapes(str string) string {
var replaced []rune
s := []byte(str)
r := 0
for r < len(s) {
if s[r] != '\\' {
c, size := utf8.DecodeRune(s[r:])
r += size
replaced = append(replaced, c)
continue
}
r += 1
if r >= len(s) {
p.bug("Escape sequence at end of string.")
return ""
}
switch s[r] {
default:
p.bug("Expected valid escape code after \\, but got %q.", s[r])
return ""
case 'b':
replaced = append(replaced, rune(0x0008))
r += 1
case 't':
replaced = append(replaced, rune(0x0009))
r += 1
case 'n':
replaced = append(replaced, rune(0x000A))
r += 1
case 'f':
replaced = append(replaced, rune(0x000C))
r += 1
case 'r':
replaced = append(replaced, rune(0x000D))
r += 1
case '"':
replaced = append(replaced, rune(0x0022))
r += 1
case '\\':
replaced = append(replaced, rune(0x005C))
r += 1
case 'u':
// At this point, we know we have a Unicode escape of the form
// `uXXXX` at [r, r+5). (Because the lexer guarantees this
// for us.)
escaped := p.asciiEscapeToUnicode(s[r+1 : r+5])
replaced = append(replaced, escaped)
r += 5
case 'U':
// At this point, we know we have a Unicode escape of the form
// `uXXXX` at [r, r+9). (Because the lexer guarantees this
// for us.)
escaped := p.asciiEscapeToUnicode(s[r+1 : r+9])
replaced = append(replaced, escaped)
r += 9
}
}
return string(replaced)
}
func (p *parser) asciiEscapeToUnicode(bs []byte) rune {
s := string(bs)
hex, err := strconv.ParseUint(strings.ToLower(s), 16, 32)
if err != nil {
p.bug("Could not parse '%s' as a hexadecimal number, but the "+
"lexer claims it's OK: %s", s, err)
}
if !utf8.ValidRune(rune(hex)) {
p.panicf("Escaped character '\\u%s' is not valid UTF-8.", s)
}
return rune(hex)
}
func isStringType(ty itemType) bool {
return ty == itemString || ty == itemMultilineString ||
ty == itemRawString || ty == itemRawMultilineString
}

@ -0,0 +1 @@
au BufWritePost *.go silent!make tags > /dev/null 2>&1

@ -0,0 +1,91 @@
package toml
// tomlType represents any Go type that corresponds to a TOML type.
// While the first draft of the TOML spec has a simplistic type system that
// probably doesn't need this level of sophistication, we seem to be militating
// toward adding real composite types.
type tomlType interface {
typeString() string
}
// typeEqual accepts any two types and returns true if they are equal.
func typeEqual(t1, t2 tomlType) bool {
if t1 == nil || t2 == nil {
return false
}
return t1.typeString() == t2.typeString()
}
func typeIsHash(t tomlType) bool {
return typeEqual(t, tomlHash) || typeEqual(t, tomlArrayHash)
}
type tomlBaseType string
func (btype tomlBaseType) typeString() string {
return string(btype)
}
func (btype tomlBaseType) String() string {
return btype.typeString()
}
var (
tomlInteger tomlBaseType = "Integer"
tomlFloat tomlBaseType = "Float"
tomlDatetime tomlBaseType = "Datetime"
tomlString tomlBaseType = "String"
tomlBool tomlBaseType = "Bool"
tomlArray tomlBaseType = "Array"
tomlHash tomlBaseType = "Hash"
tomlArrayHash tomlBaseType = "ArrayHash"
)
// typeOfPrimitive returns a tomlType of any primitive value in TOML.
// Primitive values are: Integer, Float, Datetime, String and Bool.
//
// Passing a lexer item other than the following will cause a BUG message
// to occur: itemString, itemBool, itemInteger, itemFloat, itemDatetime.
func (p *parser) typeOfPrimitive(lexItem item) tomlType {
switch lexItem.typ {
case itemInteger:
return tomlInteger
case itemFloat:
return tomlFloat
case itemDatetime:
return tomlDatetime
case itemString:
return tomlString
case itemMultilineString:
return tomlString
case itemRawString:
return tomlString
case itemRawMultilineString:
return tomlString
case itemBool:
return tomlBool
}
p.bug("Cannot infer primitive type of lex item '%s'.", lexItem)
panic("unreachable")
}
// typeOfArray returns a tomlType for an array given a list of types of its
// values.
//
// In the current spec, if an array is homogeneous, then its type is always
// "Array". If the array is not homogeneous, an error is generated.
func (p *parser) typeOfArray(types []tomlType) tomlType {
// Empty arrays are cool.
if len(types) == 0 {
return tomlArray
}
theType := types[0]
for _, t := range types[1:] {
if !typeEqual(theType, t) {
p.panicf("Array contains values of type '%s' and '%s', but "+
"arrays must be homogeneous.", theType, t)
}
}
return tomlArray
}

@ -0,0 +1,242 @@
package toml
// Struct field handling is adapted from code in encoding/json:
//
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the Go distribution.
import (
"reflect"
"sort"
"sync"
)
// A field represents a single field found in a struct.
type field struct {
name string // the name of the field (`toml` tag included)
tag bool // whether field has a `toml` tag
index []int // represents the depth of an anonymous field
typ reflect.Type // the type of the field
}
// byName sorts field by name, breaking ties with depth,
// then breaking ties with "name came from toml tag", then
// breaking ties with index sequence.
type byName []field
func (x byName) Len() int { return len(x) }
func (x byName) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
func (x byName) Less(i, j int) bool {
if x[i].name != x[j].name {
return x[i].name < x[j].name
}
if len(x[i].index) != len(x[j].index) {
return len(x[i].index) < len(x[j].index)
}
if x[i].tag != x[j].tag {
return x[i].tag
}
return byIndex(x).Less(i, j)
}
// byIndex sorts field by index sequence.
type byIndex []field
func (x byIndex) Len() int { return len(x) }
func (x byIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
func (x byIndex) Less(i, j int) bool {
for k, xik := range x[i].index {
if k >= len(x[j].index) {
return false
}
if xik != x[j].index[k] {
return xik < x[j].index[k]
}
}
return len(x[i].index) < len(x[j].index)
}
// typeFields returns a list of fields that TOML should recognize for the given
// type. The algorithm is breadth-first search over the set of structs to
// include - the top struct and then any reachable anonymous structs.
func typeFields(t reflect.Type) []field {
// Anonymous fields to explore at the current level and the next.
current := []field{}
next := []field{{typ: t}}
// Count of queued names for current level and the next.
count := map[reflect.Type]int{}
nextCount := map[reflect.Type]int{}
// Types already visited at an earlier level.
visited := map[reflect.Type]bool{}
// Fields found.
var fields []field
for len(next) > 0 {
current, next = next, current[:0]
count, nextCount = nextCount, map[reflect.Type]int{}
for _, f := range current {
if visited[f.typ] {
continue
}
visited[f.typ] = true
// Scan f.typ for fields to include.
for i := 0; i < f.typ.NumField(); i++ {
sf := f.typ.Field(i)
if sf.PkgPath != "" && !sf.Anonymous { // unexported
continue
}
opts := getOptions(sf.Tag)
if opts.skip {
continue
}
index := make([]int, len(f.index)+1)
copy(index, f.index)
index[len(f.index)] = i
ft := sf.Type
if ft.Name() == "" && ft.Kind() == reflect.Ptr {
// Follow pointer.
ft = ft.Elem()
}
// Record found field and index sequence.
if opts.name != "" || !sf.Anonymous || ft.Kind() != reflect.Struct {
tagged := opts.name != ""
name := opts.name
if name == "" {
name = sf.Name
}
fields = append(fields, field{name, tagged, index, ft})
if count[f.typ] > 1 {
// If there were multiple instances, add a second,
// so that the annihilation code will see a duplicate.
// It only cares about the distinction between 1 or 2,
// so don't bother generating any more copies.
fields = append(fields, fields[len(fields)-1])
}
continue
}
// Record new anonymous struct to explore in next round.
nextCount[ft]++
if nextCount[ft] == 1 {
f := field{name: ft.Name(), index: index, typ: ft}
next = append(next, f)
}
}
}
}
sort.Sort(byName(fields))
// Delete all fields that are hidden by the Go rules for embedded fields,
// except that fields with TOML tags are promoted.
// The fields are sorted in primary order of name, secondary order
// of field index length. Loop over names; for each name, delete
// hidden fields by choosing the one dominant field that survives.
out := fields[:0]
for advance, i := 0, 0; i < len(fields); i += advance {
// One iteration per name.
// Find the sequence of fields with the name of this first field.
fi := fields[i]
name := fi.name
for advance = 1; i+advance < len(fields); advance++ {
fj := fields[i+advance]
if fj.name != name {
break
}
}
if advance == 1 { // Only one field with this name
out = append(out, fi)
continue
}
dominant, ok := dominantField(fields[i : i+advance])
if ok {
out = append(out, dominant)
}
}
fields = out
sort.Sort(byIndex(fields))
return fields
}
// dominantField looks through the fields, all of which are known to
// have the same name, to find the single field that dominates the
// others using Go's embedding rules, modified by the presence of
// TOML tags. If there are multiple top-level fields, the boolean
// will be false: This condition is an error in Go and we skip all
// the fields.
func dominantField(fields []field) (field, bool) {
// The fields are sorted in increasing index-length order. The winner
// must therefore be one with the shortest index length. Drop all
// longer entries, which is easy: just truncate the slice.
length := len(fields[0].index)
tagged := -1 // Index of first tagged field.
for i, f := range fields {
if len(f.index) > length {
fields = fields[:i]
break
}
if f.tag {
if tagged >= 0 {
// Multiple tagged fields at the same level: conflict.
// Return no field.
return field{}, false
}
tagged = i
}
}
if tagged >= 0 {
return fields[tagged], true
}
// All remaining fields have the same length. If there's more than one,
// we have a conflict (two fields named "X" at the same level) and we
// return no field.
if len(fields) > 1 {
return field{}, false
}
return fields[0], true
}
var fieldCache struct {
sync.RWMutex
m map[reflect.Type][]field
}
// cachedTypeFields is like typeFields but uses a cache to avoid repeated work.
func cachedTypeFields(t reflect.Type) []field {
fieldCache.RLock()
f := fieldCache.m[t]
fieldCache.RUnlock()
if f != nil {
return f
}
// Compute fields without lock.
// Might duplicate effort but won't hold other computations back.
f = typeFields(t)
if f == nil {
f = []field{}
}
fieldCache.Lock()
if fieldCache.m == nil {
fieldCache.m = map[reflect.Type][]field{}
}
fieldCache.m[t] = f
fieldCache.Unlock()
return f
}

@ -0,0 +1,4 @@
.DS_Store
bin

@ -0,0 +1,13 @@
language: go
script:
- go vet ./...
- go test -v ./...
go:
- 1.3
- 1.4
- 1.5
- 1.6
- 1.7
- tip

@ -0,0 +1,8 @@
Copyright (c) 2012 Dave Grijalva
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

@ -0,0 +1,97 @@
## Migration Guide from v2 -> v3
Version 3 adds several new, frequently requested features. To do so, it introduces a few breaking changes. We've worked to keep these as minimal as possible. This guide explains the breaking changes and how you can quickly update your code.
### `Token.Claims` is now an interface type
The most requested feature from the 2.0 verison of this library was the ability to provide a custom type to the JSON parser for claims. This was implemented by introducing a new interface, `Claims`, to replace `map[string]interface{}`. We also included two concrete implementations of `Claims`: `MapClaims` and `StandardClaims`.
`MapClaims` is an alias for `map[string]interface{}` with built in validation behavior. It is the default claims type when using `Parse`. The usage is unchanged except you must type cast the claims property.
The old example for parsing a token looked like this..
```go
if token, err := jwt.Parse(tokenString, keyLookupFunc); err == nil {
fmt.Printf("Token for user %v expires %v", token.Claims["user"], token.Claims["exp"])
}
```
is now directly mapped to...
```go
if token, err := jwt.Parse(tokenString, keyLookupFunc); err == nil {
claims := token.Claims.(jwt.MapClaims)
fmt.Printf("Token for user %v expires %v", claims["user"], claims["exp"])
}
```
`StandardClaims` is designed to be embedded in your custom type. You can supply a custom claims type with the new `ParseWithClaims` function. Here's an example of using a custom claims type.
```go
type MyCustomClaims struct {
User string
*StandardClaims
}
if token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, keyLookupFunc); err == nil {
claims := token.Claims.(*MyCustomClaims)
fmt.Printf("Token for user %v expires %v", claims.User, claims.StandardClaims.ExpiresAt)
}
```
### `ParseFromRequest` has been moved
To keep this library focused on the tokens without becoming overburdened with complex request processing logic, `ParseFromRequest` and its new companion `ParseFromRequestWithClaims` have been moved to a subpackage, `request`. The method signatues have also been augmented to receive a new argument: `Extractor`.
`Extractors` do the work of picking the token string out of a request. The interface is simple and composable.
This simple parsing example:
```go
if token, err := jwt.ParseFromRequest(tokenString, req, keyLookupFunc); err == nil {
fmt.Printf("Token for user %v expires %v", token.Claims["user"], token.Claims["exp"])
}
```
is directly mapped to:
```go
if token, err := request.ParseFromRequest(req, request.OAuth2Extractor, keyLookupFunc); err == nil {
claims := token.Claims.(jwt.MapClaims)
fmt.Printf("Token for user %v expires %v", claims["user"], claims["exp"])
}
```
There are several concrete `Extractor` types provided for your convenience:
* `HeaderExtractor` will search a list of headers until one contains content.
* `ArgumentExtractor` will search a list of keys in request query and form arguments until one contains content.
* `MultiExtractor` will try a list of `Extractors` in order until one returns content.
* `AuthorizationHeaderExtractor` will look in the `Authorization` header for a `Bearer` token.
* `OAuth2Extractor` searches the places an OAuth2 token would be specified (per the spec): `Authorization` header and `access_token` argument
* `PostExtractionFilter` wraps an `Extractor`, allowing you to process the content before it's parsed. A simple example is stripping the `Bearer ` text from a header
### RSA signing methods no longer accept `[]byte` keys
Due to a [critical vulnerability](https://auth0.com/blog/2015/03/31/critical-vulnerabilities-in-json-web-token-libraries/), we've decided the convenience of accepting `[]byte` instead of `rsa.PublicKey` or `rsa.PrivateKey` isn't worth the risk of misuse.
To replace this behavior, we've added two helper methods: `ParseRSAPrivateKeyFromPEM(key []byte) (*rsa.PrivateKey, error)` and `ParseRSAPublicKeyFromPEM(key []byte) (*rsa.PublicKey, error)`. These are just simple helpers for unpacking PEM encoded PKCS1 and PKCS8 keys. If your keys are encoded any other way, all you need to do is convert them to the `crypto/rsa` package's types.
```go
func keyLookupFunc(*Token) (interface{}, error) {
// Don't forget to validate the alg is what you expect:
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
}
// Look up key
key, err := lookupPublicKey(token.Header["kid"])
if err != nil {
return nil, err
}
// Unpack key from PEM encoded PKCS8
return jwt.ParseRSAPublicKeyFromPEM(key)
}
```

@ -0,0 +1,100 @@
# jwt-go
[![Build Status](https://travis-ci.org/dgrijalva/jwt-go.svg?branch=master)](https://travis-ci.org/dgrijalva/jwt-go)
[![GoDoc](https://godoc.org/github.com/dgrijalva/jwt-go?status.svg)](https://godoc.org/github.com/dgrijalva/jwt-go)
A [go](http://www.golang.org) (or 'golang' for search engine friendliness) implementation of [JSON Web Tokens](http://self-issued.info/docs/draft-ietf-oauth-json-web-token.html)
**NEW VERSION COMING:** There have been a lot of improvements suggested since the version 3.0.0 released in 2016. I'm working now on cutting two different releases: 3.2.0 will contain any non-breaking changes or enhancements. 4.0.0 will follow shortly which will include breaking changes. See the 4.0.0 milestone to get an idea of what's coming. If you have other ideas, or would like to participate in 4.0.0, now's the time. If you depend on this library and don't want to be interrupted, I recommend you use your dependency mangement tool to pin to version 3.
**SECURITY NOTICE:** Some older versions of Go have a security issue in the cryotp/elliptic. Recommendation is to upgrade to at least 1.8.3. See issue #216 for more detail.
**SECURITY NOTICE:** It's important that you [validate the `alg` presented is what you expect](https://auth0.com/blog/2015/03/31/critical-vulnerabilities-in-json-web-token-libraries/). This library attempts to make it easy to do the right thing by requiring key types match the expected alg, but you should take the extra step to verify it in your usage. See the examples provided.
## What the heck is a JWT?
JWT.io has [a great introduction](https://jwt.io/introduction) to JSON Web Tokens.
In short, it's a signed JSON object that does something useful (for example, authentication). It's commonly used for `Bearer` tokens in Oauth 2. A token is made of three parts, separated by `.`'s. The first two parts are JSON objects, that have been [base64url](http://tools.ietf.org/html/rfc4648) encoded. The last part is the signature, encoded the same way.
The first part is called the header. It contains the necessary information for verifying the last part, the signature. For example, which encryption method was used for signing and what key was used.
The part in the middle is the interesting bit. It's called the Claims and contains the actual stuff you care about. Refer to [the RFC](http://self-issued.info/docs/draft-jones-json-web-token.html) for information about reserved keys and the proper way to add your own.
## What's in the box?
This library supports the parsing and verification as well as the generation and signing of JWTs. Current supported signing algorithms are HMAC SHA, RSA, RSA-PSS, and ECDSA, though hooks are present for adding your own.
## Examples
See [the project documentation](https://godoc.org/github.com/dgrijalva/jwt-go) for examples of usage:
* [Simple example of parsing and validating a token](https://godoc.org/github.com/dgrijalva/jwt-go#example-Parse--Hmac)
* [Simple example of building and signing a token](https://godoc.org/github.com/dgrijalva/jwt-go#example-New--Hmac)
* [Directory of Examples](https://godoc.org/github.com/dgrijalva/jwt-go#pkg-examples)
## Extensions
This library publishes all the necessary components for adding your own signing methods. Simply implement the `SigningMethod` interface and register a factory method using `RegisterSigningMethod`.
Here's an example of an extension that integrates with the Google App Engine signing tools: https://github.com/someone1/gcp-jwt-go
## Compliance
This library was last reviewed to comply with [RTF 7519](http://www.rfc-editor.org/info/rfc7519) dated May 2015 with a few notable differences:
* In order to protect against accidental use of [Unsecured JWTs](http://self-issued.info/docs/draft-ietf-oauth-json-web-token.html#UnsecuredJWT), tokens using `alg=none` will only be accepted if the constant `jwt.UnsafeAllowNoneSignatureType` is provided as the key.
## Project Status & Versioning
This library is considered production ready. Feedback and feature requests are appreciated. The API should be considered stable. There should be very few backwards-incompatible changes outside of major version updates (and only with good reason).
This project uses [Semantic Versioning 2.0.0](http://semver.org). Accepted pull requests will land on `master`. Periodically, versions will be tagged from `master`. You can find all the releases on [the project releases page](https://github.com/dgrijalva/jwt-go/releases).
While we try to make it obvious when we make breaking changes, there isn't a great mechanism for pushing announcements out to users. You may want to use this alternative package include: `gopkg.in/dgrijalva/jwt-go.v3`. It will do the right thing WRT semantic versioning.
**BREAKING CHANGES:***
* Version 3.0.0 includes _a lot_ of changes from the 2.x line, including a few that break the API. We've tried to break as few things as possible, so there should just be a few type signature changes. A full list of breaking changes is available in `VERSION_HISTORY.md`. See `MIGRATION_GUIDE.md` for more information on updating your code.
## Usage Tips
### Signing vs Encryption
A token is simply a JSON object that is signed by its author. this tells you exactly two things about the data:
* The author of the token was in the possession of the signing secret
* The data has not been modified since it was signed
It's important to know that JWT does not provide encryption, which means anyone who has access to the token can read its contents. If you need to protect (encrypt) the data, there is a companion spec, `JWE`, that provides this functionality. JWE is currently outside the scope of this library.
### Choosing a Signing Method
There are several signing methods available, and you should probably take the time to learn about the various options before choosing one. The principal design decision is most likely going to be symmetric vs asymmetric.
Symmetric signing methods, such as HSA, use only a single secret. This is probably the simplest signing method to use since any `[]byte` can be used as a valid secret. They are also slightly computationally faster to use, though this rarely is enough to matter. Symmetric signing methods work the best when both producers and consumers of tokens are trusted, or even the same system. Since the same secret is used to both sign and validate tokens, you can't easily distribute the key for validation.
Asymmetric signing methods, such as RSA, use different keys for signing and verifying tokens. This makes it possible to produce tokens with a private key, and allow any consumer to access the public key for verification.
### Signing Methods and Key Types
Each signing method expects a different object type for its signing keys. See the package documentation for details. Here are the most common ones:
* The [HMAC signing method](https://godoc.org/github.com/dgrijalva/jwt-go#SigningMethodHMAC) (`HS256`,`HS384`,`HS512`) expect `[]byte` values for signing and validation
* The [RSA signing method](https://godoc.org/github.com/dgrijalva/jwt-go#SigningMethodRSA) (`RS256`,`RS384`,`RS512`) expect `*rsa.PrivateKey` for signing and `*rsa.PublicKey` for validation
* The [ECDSA signing method](https://godoc.org/github.com/dgrijalva/jwt-go#SigningMethodECDSA) (`ES256`,`ES384`,`ES512`) expect `*ecdsa.PrivateKey` for signing and `*ecdsa.PublicKey` for validation
### JWT and OAuth
It's worth mentioning that OAuth and JWT are not the same thing. A JWT token is simply a signed JSON object. It can be used anywhere such a thing is useful. There is some confusion, though, as JWT is the most common type of bearer token used in OAuth2 authentication.
Without going too far down the rabbit hole, here's a description of the interaction of these technologies:
* OAuth is a protocol for allowing an identity provider to be separate from the service a user is logging in to. For example, whenever you use Facebook to log into a different service (Yelp, Spotify, etc), you are using OAuth.
* OAuth defines several options for passing around authentication data. One popular method is called a "bearer token". A bearer token is simply a string that _should_ only be held by an authenticated user. Thus, simply presenting this token proves your identity. You can probably derive from here why a JWT might make a good bearer token.
* Because bearer tokens are used for authentication, it's important they're kept secret. This is why transactions that use bearer tokens typically happen over SSL.
## More
Documentation can be found [on godoc.org](http://godoc.org/github.com/dgrijalva/jwt-go).
The command line utility included in this project (cmd/jwt) provides a straightforward example of token creation and parsing as well as a useful tool for debugging your own integration. You'll also find several implementation examples in the documentation.

@ -0,0 +1,118 @@
## `jwt-go` Version History
#### 3.2.0
* Added method `ParseUnverified` to allow users to split up the tasks of parsing and validation
* HMAC signing method returns `ErrInvalidKeyType` instead of `ErrInvalidKey` where appropriate
* Added options to `request.ParseFromRequest`, which allows for an arbitrary list of modifiers to parsing behavior. Initial set include `WithClaims` and `WithParser`. Existing usage of this function will continue to work as before.
* Deprecated `ParseFromRequestWithClaims` to simplify API in the future.
#### 3.1.0
* Improvements to `jwt` command line tool
* Added `SkipClaimsValidation` option to `Parser`
* Documentation updates
#### 3.0.0
* **Compatibility Breaking Changes**: See MIGRATION_GUIDE.md for tips on updating your code
* Dropped support for `[]byte` keys when using RSA signing methods. This convenience feature could contribute to security vulnerabilities involving mismatched key types with signing methods.
* `ParseFromRequest` has been moved to `request` subpackage and usage has changed
* The `Claims` property on `Token` is now type `Claims` instead of `map[string]interface{}`. The default value is type `MapClaims`, which is an alias to `map[string]interface{}`. This makes it possible to use a custom type when decoding claims.
* Other Additions and Changes
* Added `Claims` interface type to allow users to decode the claims into a custom type
* Added `ParseWithClaims`, which takes a third argument of type `Claims`. Use this function instead of `Parse` if you have a custom type you'd like to decode into.
* Dramatically improved the functionality and flexibility of `ParseFromRequest`, which is now in the `request` subpackage
* Added `ParseFromRequestWithClaims` which is the `FromRequest` equivalent of `ParseWithClaims`
* Added new interface type `Extractor`, which is used for extracting JWT strings from http requests. Used with `ParseFromRequest` and `ParseFromRequestWithClaims`.
* Added several new, more specific, validation errors to error type bitmask
* Moved examples from README to executable example files
* Signing method registry is now thread safe
* Added new property to `ValidationError`, which contains the raw error returned by calls made by parse/verify (such as those returned by keyfunc or json parser)
#### 2.7.0
This will likely be the last backwards compatible release before 3.0.0, excluding essential bug fixes.
* Added new option `-show` to the `jwt` command that will just output the decoded token without verifying
* Error text for expired tokens includes how long it's been expired
* Fixed incorrect error returned from `ParseRSAPublicKeyFromPEM`
* Documentation updates
#### 2.6.0
* Exposed inner error within ValidationError
* Fixed validation errors when using UseJSONNumber flag
* Added several unit tests
#### 2.5.0
* Added support for signing method none. You shouldn't use this. The API tries to make this clear.
* Updated/fixed some documentation
* Added more helpful error message when trying to parse tokens that begin with `BEARER `
#### 2.4.0
* Added new type, Parser, to allow for configuration of various parsing parameters
* You can now specify a list of valid signing methods. Anything outside this set will be rejected.
* You can now opt to use the `json.Number` type instead of `float64` when parsing token JSON
* Added support for [Travis CI](https://travis-ci.org/dgrijalva/jwt-go)
* Fixed some bugs with ECDSA parsing
#### 2.3.0
* Added support for ECDSA signing methods
* Added support for RSA PSS signing methods (requires go v1.4)
#### 2.2.0
* Gracefully handle a `nil` `Keyfunc` being passed to `Parse`. Result will now be the parsed token and an error, instead of a panic.
#### 2.1.0
Backwards compatible API change that was missed in 2.0.0.
* The `SignedString` method on `Token` now takes `interface{}` instead of `[]byte`
#### 2.0.0
There were two major reasons for breaking backwards compatibility with this update. The first was a refactor required to expand the width of the RSA and HMAC-SHA signing implementations. There will likely be no required code changes to support this change.
The second update, while unfortunately requiring a small change in integration, is required to open up this library to other signing methods. Not all keys used for all signing methods have a single standard on-disk representation. Requiring `[]byte` as the type for all keys proved too limiting. Additionally, this implementation allows for pre-parsed tokens to be reused, which might matter in an application that parses a high volume of tokens with a small set of keys. Backwards compatibilty has been maintained for passing `[]byte` to the RSA signing methods, but they will also accept `*rsa.PublicKey` and `*rsa.PrivateKey`.
It is likely the only integration change required here will be to change `func(t *jwt.Token) ([]byte, error)` to `func(t *jwt.Token) (interface{}, error)` when calling `Parse`.
* **Compatibility Breaking Changes**
* `SigningMethodHS256` is now `*SigningMethodHMAC` instead of `type struct`
* `SigningMethodRS256` is now `*SigningMethodRSA` instead of `type struct`
* `KeyFunc` now returns `interface{}` instead of `[]byte`
* `SigningMethod.Sign` now takes `interface{}` instead of `[]byte` for the key
* `SigningMethod.Verify` now takes `interface{}` instead of `[]byte` for the key
* Renamed type `SigningMethodHS256` to `SigningMethodHMAC`. Specific sizes are now just instances of this type.
* Added public package global `SigningMethodHS256`
* Added public package global `SigningMethodHS384`
* Added public package global `SigningMethodHS512`
* Renamed type `SigningMethodRS256` to `SigningMethodRSA`. Specific sizes are now just instances of this type.
* Added public package global `SigningMethodRS256`
* Added public package global `SigningMethodRS384`
* Added public package global `SigningMethodRS512`
* Moved sample private key for HMAC tests from an inline value to a file on disk. Value is unchanged.
* Refactored the RSA implementation to be easier to read
* Exposed helper methods `ParseRSAPrivateKeyFromPEM` and `ParseRSAPublicKeyFromPEM`
#### 1.0.2
* Fixed bug in parsing public keys from certificates
* Added more tests around the parsing of keys for RS256
* Code refactoring in RS256 implementation. No functional changes
#### 1.0.1
* Fixed panic if RS256 signing method was passed an invalid key
#### 1.0.0
* First versioned release
* API stabilized
* Supports creating, signing, parsing, and validating JWT tokens
* Supports RS256 and HS256 signing methods

@ -0,0 +1,134 @@
package jwt
import (
"crypto/subtle"
"fmt"
"time"
)
// For a type to be a Claims object, it must just have a Valid method that determines
// if the token is invalid for any supported reason
type Claims interface {
Valid() error
}
// Structured version of Claims Section, as referenced at
// https://tools.ietf.org/html/rfc7519#section-4.1
// See examples for how to use this with your own claim types
type StandardClaims struct {
Audience string `json:"aud,omitempty"`
ExpiresAt int64 `json:"exp,omitempty"`
Id string `json:"jti,omitempty"`
IssuedAt int64 `json:"iat,omitempty"`
Issuer string `json:"iss,omitempty"`
NotBefore int64 `json:"nbf,omitempty"`
Subject string `json:"sub,omitempty"`
}
// Validates time based claims "exp, iat, nbf".
// There is no accounting for clock skew.
// As well, if any of the above claims are not in the token, it will still
// be considered a valid claim.
func (c StandardClaims) Valid() error {
vErr := new(ValidationError)
now := TimeFunc().Unix()
// The claims below are optional, by default, so if they are set to the
// default value in Go, let's not fail the verification for them.
if c.VerifyExpiresAt(now, false) == false {
delta := time.Unix(now, 0).Sub(time.Unix(c.ExpiresAt, 0))
vErr.Inner = fmt.Errorf("token is expired by %v", delta)
vErr.Errors |= ValidationErrorExpired
}
if c.VerifyIssuedAt(now, false) == false {
vErr.Inner = fmt.Errorf("Token used before issued")
vErr.Errors |= ValidationErrorIssuedAt
}
if c.VerifyNotBefore(now, false) == false {
vErr.Inner = fmt.Errorf("token is not valid yet")
vErr.Errors |= ValidationErrorNotValidYet
}
if vErr.valid() {
return nil
}
return vErr
}
// Compares the aud claim against cmp.
// If required is false, this method will return true if the value matches or is unset
func (c *StandardClaims) VerifyAudience(cmp string, req bool) bool {
return verifyAud(c.Audience, cmp, req)
}
// Compares the exp claim against cmp.
// If required is false, this method will return true if the value matches or is unset
func (c *StandardClaims) VerifyExpiresAt(cmp int64, req bool) bool {
return verifyExp(c.ExpiresAt, cmp, req)
}
// Compares the iat claim against cmp.
// If required is false, this method will return true if the value matches or is unset
func (c *StandardClaims) VerifyIssuedAt(cmp int64, req bool) bool {
return verifyIat(c.IssuedAt, cmp, req)
}
// Compares the iss claim against cmp.
// If required is false, this method will return true if the value matches or is unset
func (c *StandardClaims) VerifyIssuer(cmp string, req bool) bool {
return verifyIss(c.Issuer, cmp, req)
}
// Compares the nbf claim against cmp.
// If required is false, this method will return true if the value matches or is unset
func (c *StandardClaims) VerifyNotBefore(cmp int64, req bool) bool {
return verifyNbf(c.NotBefore, cmp, req)
}
// ----- helpers
func verifyAud(aud string, cmp string, required bool) bool {
if aud == "" {
return !required
}
if subtle.ConstantTimeCompare([]byte(aud), []byte(cmp)) != 0 {
return true
} else {
return false
}
}
func verifyExp(exp int64, now int64, required bool) bool {
if exp == 0 {
return !required
}
return now <= exp
}
func verifyIat(iat int64, now int64, required bool) bool {
if iat == 0 {
return !required
}
return now >= iat
}
func verifyIss(iss string, cmp string, required bool) bool {
if iss == "" {
return !required
}
if subtle.ConstantTimeCompare([]byte(iss), []byte(cmp)) != 0 {
return true
} else {
return false
}
}
func verifyNbf(nbf int64, now int64, required bool) bool {
if nbf == 0 {
return !required
}
return now >= nbf
}

@ -0,0 +1,4 @@
// Package jwt is a Go implementation of JSON Web Tokens: http://self-issued.info/docs/draft-jones-json-web-token.html
//
// See README.md for more info.
package jwt

@ -0,0 +1,148 @@
package jwt
import (
"crypto"
"crypto/ecdsa"
"crypto/rand"
"errors"
"math/big"
)
var (
// Sadly this is missing from crypto/ecdsa compared to crypto/rsa
ErrECDSAVerification = errors.New("crypto/ecdsa: verification error")
)
// Implements the ECDSA family of signing methods signing methods
// Expects *ecdsa.PrivateKey for signing and *ecdsa.PublicKey for verification
type SigningMethodECDSA struct {
Name string
Hash crypto.Hash
KeySize int
CurveBits int
}
// Specific instances for EC256 and company
var (
SigningMethodES256 *SigningMethodECDSA
SigningMethodES384 *SigningMethodECDSA
SigningMethodES512 *SigningMethodECDSA
)
func init() {
// ES256
SigningMethodES256 = &SigningMethodECDSA{"ES256", crypto.SHA256, 32, 256}
RegisterSigningMethod(SigningMethodES256.Alg(), func() SigningMethod {
return SigningMethodES256
})
// ES384
SigningMethodES384 = &SigningMethodECDSA{"ES384", crypto.SHA384, 48, 384}
RegisterSigningMethod(SigningMethodES384.Alg(), func() SigningMethod {
return SigningMethodES384
})
// ES512
SigningMethodES512 = &SigningMethodECDSA{"ES512", crypto.SHA512, 66, 521}
RegisterSigningMethod(SigningMethodES512.Alg(), func() SigningMethod {
return SigningMethodES512
})
}
func (m *SigningMethodECDSA) Alg() string {
return m.Name
}
// Implements the Verify method from SigningMethod
// For this verify method, key must be an ecdsa.PublicKey struct
func (m *SigningMethodECDSA) Verify(signingString, signature string, key interface{}) error {
var err error
// Decode the signature
var sig []byte
if sig, err = DecodeSegment(signature); err != nil {
return err
}
// Get the key
var ecdsaKey *ecdsa.PublicKey
switch k := key.(type) {
case *ecdsa.PublicKey:
ecdsaKey = k
default:
return ErrInvalidKeyType
}
if len(sig) != 2*m.KeySize {
return ErrECDSAVerification
}
r := big.NewInt(0).SetBytes(sig[:m.KeySize])
s := big.NewInt(0).SetBytes(sig[m.KeySize:])
// Create hasher
if !m.Hash.Available() {
return ErrHashUnavailable
}
hasher := m.Hash.New()
hasher.Write([]byte(signingString))
// Verify the signature
if verifystatus := ecdsa.Verify(ecdsaKey, hasher.Sum(nil), r, s); verifystatus == true {
return nil
} else {
return ErrECDSAVerification
}
}
// Implements the Sign method from SigningMethod
// For this signing method, key must be an ecdsa.PrivateKey struct
func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) (string, error) {
// Get the key
var ecdsaKey *ecdsa.PrivateKey
switch k := key.(type) {
case *ecdsa.PrivateKey:
ecdsaKey = k
default:
return "", ErrInvalidKeyType
}
// Create the hasher
if !m.Hash.Available() {
return "", ErrHashUnavailable
}
hasher := m.Hash.New()
hasher.Write([]byte(signingString))
// Sign the string and return r, s
if r, s, err := ecdsa.Sign(rand.Reader, ecdsaKey, hasher.Sum(nil)); err == nil {
curveBits := ecdsaKey.Curve.Params().BitSize
if m.CurveBits != curveBits {
return "", ErrInvalidKey
}
keyBytes := curveBits / 8
if curveBits%8 > 0 {
keyBytes += 1
}
// We serialize the outpus (r and s) into big-endian byte arrays and pad
// them with zeros on the left to make sure the sizes work out. Both arrays
// must be keyBytes long, and the output must be 2*keyBytes long.
rBytes := r.Bytes()
rBytesPadded := make([]byte, keyBytes)
copy(rBytesPadded[keyBytes-len(rBytes):], rBytes)
sBytes := s.Bytes()
sBytesPadded := make([]byte, keyBytes)
copy(sBytesPadded[keyBytes-len(sBytes):], sBytes)
out := append(rBytesPadded, sBytesPadded...)
return EncodeSegment(out), nil
} else {
return "", err
}
}

@ -0,0 +1,67 @@
package jwt
import (
"crypto/ecdsa"
"crypto/x509"
"encoding/pem"
"errors"
)
var (
ErrNotECPublicKey = errors.New("Key is not a valid ECDSA public key")
ErrNotECPrivateKey = errors.New("Key is not a valid ECDSA private key")
)
// Parse PEM encoded Elliptic Curve Private Key Structure
func ParseECPrivateKeyFromPEM(key []byte) (*ecdsa.PrivateKey, error) {
var err error
// Parse PEM block
var block *pem.Block
if block, _ = pem.Decode(key); block == nil {
return nil, ErrKeyMustBePEMEncoded
}
// Parse the key
var parsedKey interface{}
if parsedKey, err = x509.ParseECPrivateKey(block.Bytes); err != nil {
return nil, err
}
var pkey *ecdsa.PrivateKey
var ok bool
if pkey, ok = parsedKey.(*ecdsa.PrivateKey); !ok {
return nil, ErrNotECPrivateKey
}
return pkey, nil
}
// Parse PEM encoded PKCS1 or PKCS8 public key
func ParseECPublicKeyFromPEM(key []byte) (*ecdsa.PublicKey, error) {
var err error
// Parse PEM block
var block *pem.Block
if block, _ = pem.Decode(key); block == nil {
return nil, ErrKeyMustBePEMEncoded
}
// Parse the key
var parsedKey interface{}
if parsedKey, err = x509.ParsePKIXPublicKey(block.Bytes); err != nil {
if cert, err := x509.ParseCertificate(block.Bytes); err == nil {
parsedKey = cert.PublicKey
} else {
return nil, err
}
}
var pkey *ecdsa.PublicKey
var ok bool
if pkey, ok = parsedKey.(*ecdsa.PublicKey); !ok {
return nil, ErrNotECPublicKey
}
return pkey, nil
}

@ -0,0 +1,59 @@
package jwt
import (
"errors"
)
// Error constants
var (
ErrInvalidKey = errors.New("key is invalid")
ErrInvalidKeyType = errors.New("key is of invalid type")
ErrHashUnavailable = errors.New("the requested hash function is unavailable")
)
// The errors that might occur when parsing and validating a token
const (
ValidationErrorMalformed uint32 = 1 << iota // Token is malformed
ValidationErrorUnverifiable // Token could not be verified because of signing problems
ValidationErrorSignatureInvalid // Signature validation failed
// Standard Claim validation errors
ValidationErrorAudience // AUD validation failed
ValidationErrorExpired // EXP validation failed
ValidationErrorIssuedAt // IAT validation failed
ValidationErrorIssuer // ISS validation failed
ValidationErrorNotValidYet // NBF validation failed
ValidationErrorId // JTI validation failed
ValidationErrorClaimsInvalid // Generic claims validation error
)
// Helper for constructing a ValidationError with a string error message
func NewValidationError(errorText string, errorFlags uint32) *ValidationError {
return &ValidationError{
text: errorText,
Errors: errorFlags,
}
}
// The error from Parse if token is not valid
type ValidationError struct {
Inner error // stores the error returned by external dependencies, i.e.: KeyFunc
Errors uint32 // bitfield. see ValidationError... constants
text string // errors that do not have a valid error just have text
}
// Validation error is an error type
func (e ValidationError) Error() string {
if e.Inner != nil {
return e.Inner.Error()
} else if e.text != "" {
return e.text
} else {
return "token is invalid"
}
}
// No errors
func (e *ValidationError) valid() bool {
return e.Errors == 0
}

@ -0,0 +1,95 @@
package jwt
import (
"crypto"
"crypto/hmac"
"errors"
)
// Implements the HMAC-SHA family of signing methods signing methods
// Expects key type of []byte for both signing and validation
type SigningMethodHMAC struct {
Name string
Hash crypto.Hash
}
// Specific instances for HS256 and company
var (
SigningMethodHS256 *SigningMethodHMAC
SigningMethodHS384 *SigningMethodHMAC
SigningMethodHS512 *SigningMethodHMAC
ErrSignatureInvalid = errors.New("signature is invalid")
)
func init() {
// HS256
SigningMethodHS256 = &SigningMethodHMAC{"HS256", crypto.SHA256}
RegisterSigningMethod(SigningMethodHS256.Alg(), func() SigningMethod {
return SigningMethodHS256
})
// HS384
SigningMethodHS384 = &SigningMethodHMAC{"HS384", crypto.SHA384}
RegisterSigningMethod(SigningMethodHS384.Alg(), func() SigningMethod {
return SigningMethodHS384
})
// HS512
SigningMethodHS512 = &SigningMethodHMAC{"HS512", crypto.SHA512}
RegisterSigningMethod(SigningMethodHS512.Alg(), func() SigningMethod {
return SigningMethodHS512
})
}
func (m *SigningMethodHMAC) Alg() string {
return m.Name
}
// Verify the signature of HSXXX tokens. Returns nil if the signature is valid.
func (m *SigningMethodHMAC) Verify(signingString, signature string, key interface{}) error {
// Verify the key is the right type
keyBytes, ok := key.([]byte)
if !ok {
return ErrInvalidKeyType
}
// Decode signature, for comparison
sig, err := DecodeSegment(signature)
if err != nil {
return err
}
// Can we use the specified hashing method?
if !m.Hash.Available() {
return ErrHashUnavailable
}
// This signing method is symmetric, so we validate the signature
// by reproducing the signature from the signing string and key, then
// comparing that against the provided signature.
hasher := hmac.New(m.Hash.New, keyBytes)
hasher.Write([]byte(signingString))
if !hmac.Equal(sig, hasher.Sum(nil)) {
return ErrSignatureInvalid
}
// No validation errors. Signature is good.
return nil
}
// Implements the Sign method from SigningMethod for this signing method.
// Key must be []byte
func (m *SigningMethodHMAC) Sign(signingString string, key interface{}) (string, error) {
if keyBytes, ok := key.([]byte); ok {
if !m.Hash.Available() {
return "", ErrHashUnavailable
}
hasher := hmac.New(m.Hash.New, keyBytes)
hasher.Write([]byte(signingString))
return EncodeSegment(hasher.Sum(nil)), nil
}
return "", ErrInvalidKeyType
}

@ -0,0 +1,94 @@
package jwt
import (
"encoding/json"
"errors"
// "fmt"
)
// Claims type that uses the map[string]interface{} for JSON decoding
// This is the default claims type if you don't supply one
type MapClaims map[string]interface{}
// Compares the aud claim against cmp.
// If required is false, this method will return true if the value matches or is unset
func (m MapClaims) VerifyAudience(cmp string, req bool) bool {
aud, _ := m["aud"].(string)
return verifyAud(aud, cmp, req)
}
// Compares the exp claim against cmp.
// If required is false, this method will return true if the value matches or is unset
func (m MapClaims) VerifyExpiresAt(cmp int64, req bool) bool {
switch exp := m["exp"].(type) {
case float64:
return verifyExp(int64(exp), cmp, req)
case json.Number:
v, _ := exp.Int64()
return verifyExp(v, cmp, req)
}
return req == false
}
// Compares the iat claim against cmp.
// If required is false, this method will return true if the value matches or is unset
func (m MapClaims) VerifyIssuedAt(cmp int64, req bool) bool {
switch iat := m["iat"].(type) {
case float64:
return verifyIat(int64(iat), cmp, req)
case json.Number:
v, _ := iat.Int64()
return verifyIat(v, cmp, req)
}
return req == false
}
// Compares the iss claim against cmp.
// If required is false, this method will return true if the value matches or is unset
func (m MapClaims) VerifyIssuer(cmp string, req bool) bool {
iss, _ := m["iss"].(string)
return verifyIss(iss, cmp, req)
}
// Compares the nbf claim against cmp.
// If required is false, this method will return true if the value matches or is unset
func (m MapClaims) VerifyNotBefore(cmp int64, req bool) bool {
switch nbf := m["nbf"].(type) {
case float64:
return verifyNbf(int64(nbf), cmp, req)
case json.Number:
v, _ := nbf.Int64()
return verifyNbf(v, cmp, req)
}
return req == false
}
// Validates time based claims "exp, iat, nbf".
// There is no accounting for clock skew.
// As well, if any of the above claims are not in the token, it will still
// be considered a valid claim.
func (m MapClaims) Valid() error {
vErr := new(ValidationError)
now := TimeFunc().Unix()
if m.VerifyExpiresAt(now, false) == false {
vErr.Inner = errors.New("Token is expired")
vErr.Errors |= ValidationErrorExpired
}
if m.VerifyIssuedAt(now, false) == false {
vErr.Inner = errors.New("Token used before issued")
vErr.Errors |= ValidationErrorIssuedAt
}
if m.VerifyNotBefore(now, false) == false {
vErr.Inner = errors.New("Token is not valid yet")
vErr.Errors |= ValidationErrorNotValidYet
}
if vErr.valid() {
return nil
}
return vErr
}

@ -0,0 +1,52 @@
package jwt
// Implements the none signing method. This is required by the spec
// but you probably should never use it.
var SigningMethodNone *signingMethodNone
const UnsafeAllowNoneSignatureType unsafeNoneMagicConstant = "none signing method allowed"
var NoneSignatureTypeDisallowedError error
type signingMethodNone struct{}
type unsafeNoneMagicConstant string
func init() {
SigningMethodNone = &signingMethodNone{}
NoneSignatureTypeDisallowedError = NewValidationError("'none' signature type is not allowed", ValidationErrorSignatureInvalid)
RegisterSigningMethod(SigningMethodNone.Alg(), func() SigningMethod {
return SigningMethodNone
})
}
func (m *signingMethodNone) Alg() string {
return "none"
}
// Only allow 'none' alg type if UnsafeAllowNoneSignatureType is specified as the key
func (m *signingMethodNone) Verify(signingString, signature string, key interface{}) (err error) {
// Key must be UnsafeAllowNoneSignatureType to prevent accidentally
// accepting 'none' signing method
if _, ok := key.(unsafeNoneMagicConstant); !ok {
return NoneSignatureTypeDisallowedError
}
// If signing method is none, signature must be an empty string
if signature != "" {
return NewValidationError(
"'none' signing method with non-empty signature",
ValidationErrorSignatureInvalid,
)
}
// Accept 'none' signing method.
return nil
}
// Only allow 'none' signing if UnsafeAllowNoneSignatureType is specified as the key
func (m *signingMethodNone) Sign(signingString string, key interface{}) (string, error) {
if _, ok := key.(unsafeNoneMagicConstant); ok {
return "", nil
}
return "", NoneSignatureTypeDisallowedError
}

@ -0,0 +1,148 @@
package jwt
import (
"bytes"
"encoding/json"
"fmt"
"strings"
)
type Parser struct {
ValidMethods []string // If populated, only these methods will be considered valid
UseJSONNumber bool // Use JSON Number format in JSON decoder
SkipClaimsValidation bool // Skip claims validation during token parsing
}
// Parse, validate, and return a token.
// keyFunc will receive the parsed token and should return the key for validating.
// If everything is kosher, err will be nil
func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) {
return p.ParseWithClaims(tokenString, MapClaims{}, keyFunc)
}
func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc) (*Token, error) {
token, parts, err := p.ParseUnverified(tokenString, claims)
if err != nil {
return token, err
}
// Verify signing method is in the required set
if p.ValidMethods != nil {
var signingMethodValid = false
var alg = token.Method.Alg()
for _, m := range p.ValidMethods {
if m == alg {
signingMethodValid = true
break
}
}
if !signingMethodValid {
// signing method is not in the listed set
return token, NewValidationError(fmt.Sprintf("signing method %v is invalid", alg), ValidationErrorSignatureInvalid)
}
}
// Lookup key
var key interface{}
if keyFunc == nil {
// keyFunc was not provided. short circuiting validation
return token, NewValidationError("no Keyfunc was provided.", ValidationErrorUnverifiable)
}
if key, err = keyFunc(token); err != nil {
// keyFunc returned an error
if ve, ok := err.(*ValidationError); ok {
return token, ve
}
return token, &ValidationError{Inner: err, Errors: ValidationErrorUnverifiable}
}
vErr := &ValidationError{}
// Validate Claims
if !p.SkipClaimsValidation {
if err := token.Claims.Valid(); err != nil {
// If the Claims Valid returned an error, check if it is a validation error,
// If it was another error type, create a ValidationError with a generic ClaimsInvalid flag set
if e, ok := err.(*ValidationError); !ok {
vErr = &ValidationError{Inner: err, Errors: ValidationErrorClaimsInvalid}
} else {
vErr = e
}
}
}
// Perform validation
token.Signature = parts[2]
if err = token.Method.Verify(strings.Join(parts[0:2], "."), token.Signature, key); err != nil {
vErr.Inner = err
vErr.Errors |= ValidationErrorSignatureInvalid
}
if vErr.valid() {
token.Valid = true
return token, nil
}
return token, vErr
}
// WARNING: Don't use this method unless you know what you're doing
//
// This method parses the token but doesn't validate the signature. It's only
// ever useful in cases where you know the signature is valid (because it has
// been checked previously in the stack) and you want to extract values from
// it.
func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Token, parts []string, err error) {
parts = strings.Split(tokenString, ".")
if len(parts) != 3 {
return nil, parts, NewValidationError("token contains an invalid number of segments", ValidationErrorMalformed)
}
token = &Token{Raw: tokenString}
// parse Header
var headerBytes []byte
if headerBytes, err = DecodeSegment(parts[0]); err != nil {
if strings.HasPrefix(strings.ToLower(tokenString), "bearer ") {
return token, parts, NewValidationError("tokenstring should not contain 'bearer '", ValidationErrorMalformed)
}
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
}
if err = json.Unmarshal(headerBytes, &token.Header); err != nil {
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
}
// parse Claims
var claimBytes []byte
token.Claims = claims
if claimBytes, err = DecodeSegment(parts[1]); err != nil {
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
}
dec := json.NewDecoder(bytes.NewBuffer(claimBytes))
if p.UseJSONNumber {
dec.UseNumber()
}
// JSON Decode. Special case for map type to avoid weird pointer behavior
if c, ok := token.Claims.(MapClaims); ok {
err = dec.Decode(&c)
} else {
err = dec.Decode(&claims)
}
// Handle decode error
if err != nil {
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
}
// Lookup signature method
if method, ok := token.Header["alg"].(string); ok {
if token.Method = GetSigningMethod(method); token.Method == nil {
return token, parts, NewValidationError("signing method (alg) is unavailable.", ValidationErrorUnverifiable)
}
} else {
return token, parts, NewValidationError("signing method (alg) is unspecified.", ValidationErrorUnverifiable)
}
return token, parts, nil
}

@ -0,0 +1,101 @@
package jwt
import (
"crypto"
"crypto/rand"
"crypto/rsa"
)
// Implements the RSA family of signing methods signing methods
// Expects *rsa.PrivateKey for signing and *rsa.PublicKey for validation
type SigningMethodRSA struct {
Name string
Hash crypto.Hash
}
// Specific instances for RS256 and company
var (
SigningMethodRS256 *SigningMethodRSA
SigningMethodRS384 *SigningMethodRSA
SigningMethodRS512 *SigningMethodRSA
)
func init() {
// RS256
SigningMethodRS256 = &SigningMethodRSA{"RS256", crypto.SHA256}
RegisterSigningMethod(SigningMethodRS256.Alg(), func() SigningMethod {
return SigningMethodRS256
})
// RS384
SigningMethodRS384 = &SigningMethodRSA{"RS384", crypto.SHA384}
RegisterSigningMethod(SigningMethodRS384.Alg(), func() SigningMethod {
return SigningMethodRS384
})
// RS512
SigningMethodRS512 = &SigningMethodRSA{"RS512", crypto.SHA512}
RegisterSigningMethod(SigningMethodRS512.Alg(), func() SigningMethod {
return SigningMethodRS512
})
}
func (m *SigningMethodRSA) Alg() string {
return m.Name
}
// Implements the Verify method from SigningMethod
// For this signing method, must be an *rsa.PublicKey structure.
func (m *SigningMethodRSA) Verify(signingString, signature string, key interface{}) error {
var err error
// Decode the signature
var sig []byte
if sig, err = DecodeSegment(signature); err != nil {
return err
}
var rsaKey *rsa.PublicKey
var ok bool
if rsaKey, ok = key.(*rsa.PublicKey); !ok {
return ErrInvalidKeyType
}
// Create hasher
if !m.Hash.Available() {
return ErrHashUnavailable
}
hasher := m.Hash.New()
hasher.Write([]byte(signingString))
// Verify the signature
return rsa.VerifyPKCS1v15(rsaKey, m.Hash, hasher.Sum(nil), sig)
}
// Implements the Sign method from SigningMethod
// For this signing method, must be an *rsa.PrivateKey structure.
func (m *SigningMethodRSA) Sign(signingString string, key interface{}) (string, error) {
var rsaKey *rsa.PrivateKey
var ok bool
// Validate type of key
if rsaKey, ok = key.(*rsa.PrivateKey); !ok {
return "", ErrInvalidKey
}
// Create the hasher
if !m.Hash.Available() {
return "", ErrHashUnavailable
}
hasher := m.Hash.New()
hasher.Write([]byte(signingString))
// Sign the string and return the encoded bytes
if sigBytes, err := rsa.SignPKCS1v15(rand.Reader, rsaKey, m.Hash, hasher.Sum(nil)); err == nil {
return EncodeSegment(sigBytes), nil
} else {
return "", err
}
}

@ -0,0 +1,126 @@
// +build go1.4
package jwt
import (
"crypto"
"crypto/rand"
"crypto/rsa"
)
// Implements the RSAPSS family of signing methods signing methods
type SigningMethodRSAPSS struct {
*SigningMethodRSA
Options *rsa.PSSOptions
}
// Specific instances for RS/PS and company
var (
SigningMethodPS256 *SigningMethodRSAPSS
SigningMethodPS384 *SigningMethodRSAPSS
SigningMethodPS512 *SigningMethodRSAPSS
)
func init() {
// PS256
SigningMethodPS256 = &SigningMethodRSAPSS{
&SigningMethodRSA{
Name: "PS256",
Hash: crypto.SHA256,
},
&rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthAuto,
Hash: crypto.SHA256,
},
}
RegisterSigningMethod(SigningMethodPS256.Alg(), func() SigningMethod {
return SigningMethodPS256
})
// PS384
SigningMethodPS384 = &SigningMethodRSAPSS{
&SigningMethodRSA{
Name: "PS384",
Hash: crypto.SHA384,
},
&rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthAuto,
Hash: crypto.SHA384,
},
}
RegisterSigningMethod(SigningMethodPS384.Alg(), func() SigningMethod {
return SigningMethodPS384
})
// PS512
SigningMethodPS512 = &SigningMethodRSAPSS{
&SigningMethodRSA{
Name: "PS512",
Hash: crypto.SHA512,
},
&rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthAuto,
Hash: crypto.SHA512,
},
}
RegisterSigningMethod(SigningMethodPS512.Alg(), func() SigningMethod {
return SigningMethodPS512
})
}
// Implements the Verify method from SigningMethod
// For this verify method, key must be an rsa.PublicKey struct
func (m *SigningMethodRSAPSS) Verify(signingString, signature string, key interface{}) error {
var err error
// Decode the signature
var sig []byte
if sig, err = DecodeSegment(signature); err != nil {
return err
}
var rsaKey *rsa.PublicKey
switch k := key.(type) {
case *rsa.PublicKey:
rsaKey = k
default:
return ErrInvalidKey
}
// Create hasher
if !m.Hash.Available() {
return ErrHashUnavailable
}
hasher := m.Hash.New()
hasher.Write([]byte(signingString))
return rsa.VerifyPSS(rsaKey, m.Hash, hasher.Sum(nil), sig, m.Options)
}
// Implements the Sign method from SigningMethod
// For this signing method, key must be an rsa.PrivateKey struct
func (m *SigningMethodRSAPSS) Sign(signingString string, key interface{}) (string, error) {
var rsaKey *rsa.PrivateKey
switch k := key.(type) {
case *rsa.PrivateKey:
rsaKey = k
default:
return "", ErrInvalidKeyType
}
// Create the hasher
if !m.Hash.Available() {
return "", ErrHashUnavailable
}
hasher := m.Hash.New()
hasher.Write([]byte(signingString))
// Sign the string and return the encoded bytes
if sigBytes, err := rsa.SignPSS(rand.Reader, rsaKey, m.Hash, hasher.Sum(nil), m.Options); err == nil {
return EncodeSegment(sigBytes), nil
} else {
return "", err
}
}

@ -0,0 +1,101 @@
package jwt
import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
)
var (
ErrKeyMustBePEMEncoded = errors.New("Invalid Key: Key must be PEM encoded PKCS1 or PKCS8 private key")
ErrNotRSAPrivateKey = errors.New("Key is not a valid RSA private key")
ErrNotRSAPublicKey = errors.New("Key is not a valid RSA public key")
)
// Parse PEM encoded PKCS1 or PKCS8 private key
func ParseRSAPrivateKeyFromPEM(key []byte) (*rsa.PrivateKey, error) {
var err error
// Parse PEM block
var block *pem.Block
if block, _ = pem.Decode(key); block == nil {
return nil, ErrKeyMustBePEMEncoded
}
var parsedKey interface{}
if parsedKey, err = x509.ParsePKCS1PrivateKey(block.Bytes); err != nil {
if parsedKey, err = x509.ParsePKCS8PrivateKey(block.Bytes); err != nil {
return nil, err
}
}
var pkey *rsa.PrivateKey
var ok bool
if pkey, ok = parsedKey.(*rsa.PrivateKey); !ok {
return nil, ErrNotRSAPrivateKey
}
return pkey, nil
}
// Parse PEM encoded PKCS1 or PKCS8 private key protected with password
func ParseRSAPrivateKeyFromPEMWithPassword(key []byte, password string) (*rsa.PrivateKey, error) {
var err error
// Parse PEM block
var block *pem.Block
if block, _ = pem.Decode(key); block == nil {
return nil, ErrKeyMustBePEMEncoded
}
var parsedKey interface{}
var blockDecrypted []byte
if blockDecrypted, err = x509.DecryptPEMBlock(block, []byte(password)); err != nil {
return nil, err
}
if parsedKey, err = x509.ParsePKCS1PrivateKey(blockDecrypted); err != nil {
if parsedKey, err = x509.ParsePKCS8PrivateKey(blockDecrypted); err != nil {
return nil, err
}
}
var pkey *rsa.PrivateKey
var ok bool
if pkey, ok = parsedKey.(*rsa.PrivateKey); !ok {
return nil, ErrNotRSAPrivateKey
}
return pkey, nil
}
// Parse PEM encoded PKCS1 or PKCS8 public key
func ParseRSAPublicKeyFromPEM(key []byte) (*rsa.PublicKey, error) {
var err error
// Parse PEM block
var block *pem.Block
if block, _ = pem.Decode(key); block == nil {
return nil, ErrKeyMustBePEMEncoded
}
// Parse the key
var parsedKey interface{}
if parsedKey, err = x509.ParsePKIXPublicKey(block.Bytes); err != nil {
if cert, err := x509.ParseCertificate(block.Bytes); err == nil {
parsedKey = cert.PublicKey
} else {
return nil, err
}
}
var pkey *rsa.PublicKey
var ok bool
if pkey, ok = parsedKey.(*rsa.PublicKey); !ok {
return nil, ErrNotRSAPublicKey
}
return pkey, nil
}

@ -0,0 +1,35 @@
package jwt
import (
"sync"
)
var signingMethods = map[string]func() SigningMethod{}
var signingMethodLock = new(sync.RWMutex)
// Implement SigningMethod to add new methods for signing or verifying tokens.
type SigningMethod interface {
Verify(signingString, signature string, key interface{}) error // Returns nil if signature is valid
Sign(signingString string, key interface{}) (string, error) // Returns encoded signature or error
Alg() string // returns the alg identifier for this method (example: 'HS256')
}
// Register the "alg" name and a factory function for signing method.
// This is typically done during init() in the method's implementation
func RegisterSigningMethod(alg string, f func() SigningMethod) {
signingMethodLock.Lock()
defer signingMethodLock.Unlock()
signingMethods[alg] = f
}
// Get a signing method from an "alg" string
func GetSigningMethod(alg string) (method SigningMethod) {
signingMethodLock.RLock()
defer signingMethodLock.RUnlock()
if methodF, ok := signingMethods[alg]; ok {
method = methodF()
}
return
}

@ -0,0 +1,108 @@
package jwt
import (
"encoding/base64"
"encoding/json"
"strings"
"time"
)
// TimeFunc provides the current time when parsing token to validate "exp" claim (expiration time).
// You can override it to use another time value. This is useful for testing or if your
// server uses a different time zone than your tokens.
var TimeFunc = time.Now
// Parse methods use this callback function to supply
// the key for verification. The function receives the parsed,
// but unverified Token. This allows you to use properties in the
// Header of the token (such as `kid`) to identify which key to use.
type Keyfunc func(*Token) (interface{}, error)
// A JWT Token. Different fields will be used depending on whether you're
// creating or parsing/verifying a token.
type Token struct {
Raw string // The raw token. Populated when you Parse a token
Method SigningMethod // The signing method used or to be used
Header map[string]interface{} // The first segment of the token
Claims Claims // The second segment of the token
Signature string // The third segment of the token. Populated when you Parse a token
Valid bool // Is the token valid? Populated when you Parse/Verify a token
}
// Create a new Token. Takes a signing method
func New(method SigningMethod) *Token {
return NewWithClaims(method, MapClaims{})
}
func NewWithClaims(method SigningMethod, claims Claims) *Token {
return &Token{
Header: map[string]interface{}{
"typ": "JWT",
"alg": method.Alg(),
},
Claims: claims,
Method: method,
}
}
// Get the complete, signed token
func (t *Token) SignedString(key interface{}) (string, error) {
var sig, sstr string
var err error
if sstr, err = t.SigningString(); err != nil {
return "", err
}
if sig, err = t.Method.Sign(sstr, key); err != nil {
return "", err
}
return strings.Join([]string{sstr, sig}, "."), nil
}
// Generate the signing string. This is the
// most expensive part of the whole deal. Unless you
// need this for something special, just go straight for
// the SignedString.
func (t *Token) SigningString() (string, error) {
var err error
parts := make([]string, 2)
for i, _ := range parts {
var jsonValue []byte
if i == 0 {
if jsonValue, err = json.Marshal(t.Header); err != nil {
return "", err
}
} else {
if jsonValue, err = json.Marshal(t.Claims); err != nil {
return "", err
}
}
parts[i] = EncodeSegment(jsonValue)
}
return strings.Join(parts, "."), nil
}
// Parse, validate, and return a token.
// keyFunc will receive the parsed token and should return the key for validating.
// If everything is kosher, err will be nil
func Parse(tokenString string, keyFunc Keyfunc) (*Token, error) {
return new(Parser).Parse(tokenString, keyFunc)
}
func ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc) (*Token, error) {
return new(Parser).ParseWithClaims(tokenString, claims, keyFunc)
}
// Encode JWT specific base64url encoding with padding stripped
func EncodeSegment(seg []byte) string {
return strings.TrimRight(base64.URLEncoding.EncodeToString(seg), "=")
}
// Decode JWT specific base64url encoding with padding stripped
func DecodeSegment(seg string) ([]byte, error) {
if l := len(seg) % 4; l > 0 {
seg += strings.Repeat("=", 4-l)
}
return base64.URLEncoding.DecodeString(seg)
}

@ -0,0 +1,5 @@
gocql-fuzz
fuzz-corpus
fuzz-work
gocql.test
.idea

@ -0,0 +1,49 @@
language: go
sudo: required
dist: trusty
cache:
directories:
- $HOME/.ccm/repository
- $HOME/.local/lib/python2.7
matrix:
fast_finish: true
branches:
only:
- master
env:
global:
- GOMAXPROCS=2
matrix:
- CASS=2.1.21
AUTH=true
- CASS=2.2.14
AUTH=true
- CASS=2.2.14
AUTH=false
- CASS=3.0.18
AUTH=false
- CASS=3.11.4
AUTH=false
go:
- 1.12.x
- 1.13.x
install:
- ./install_test_deps.sh $TRAVIS_REPO_SLUG
- cd ../..
- cd gocql/gocql
- go get .
script:
- set -e
- PATH=$PATH:$HOME/.local/bin bash integration.sh $CASS $AUTH
- go vet .
notifications:
- email: false

@ -0,0 +1,114 @@
# This source file refers to The gocql Authors for copyright purposes.
Christoph Hack <christoph@tux21b.org>
Jonathan Rudenberg <jonathan@titanous.com>
Thorsten von Eicken <tve@rightscale.com>
Matt Robenolt <mattr@disqus.com>
Phillip Couto <phillip.couto@stemstudios.com>
Niklas Korz <korz.niklask@gmail.com>
Nimi Wariboko Jr <nimi@channelmeter.com>
Ghais Issa <ghais.issa@gmail.com>
Sasha Klizhentas <klizhentas@gmail.com>
Konstantin Cherkasov <k.cherkasoff@gmail.com>
Ben Hood <0x6e6562@gmail.com>
Pete Hopkins <phopkins@gmail.com>
Chris Bannister <c.bannister@gmail.com>
Maxim Bublis <b@codemonkey.ru>
Alex Zorin <git@zor.io>
Kasper Middelboe Petersen <me@phant.dk>
Harpreet Sawhney <harpreet.sawhney@gmail.com>
Charlie Andrews <charlieandrews.cwa@gmail.com>
Stanislavs Koikovs <stanislavs.koikovs@gmail.com>
Dan Forest <bonjour@dan.tf>
Miguel Serrano <miguelvps@gmail.com>
Stefan Radomski <gibheer@zero-knowledge.org>
Josh Wright <jshwright@gmail.com>
Jacob Rhoden <jacob.rhoden@gmail.com>
Ben Frye <benfrye@gmail.com>
Fred McCann <fred@sharpnoodles.com>
Dan Simmons <dan@simmons.io>
Muir Manders <muir@retailnext.net>
Sankar P <sankar.curiosity@gmail.com>
Julien Da Silva <julien.dasilva@gmail.com>
Dan Kennedy <daniel@firstcs.co.uk>
Nick Dhupia<nick.dhupia@gmail.com>
Yasuharu Goto <matope.ono@gmail.com>
Jeremy Schlatter <jeremy.schlatter@gmail.com>
Matthias Kadenbach <matthias.kadenbach@gmail.com>
Dean Elbaz <elbaz.dean@gmail.com>
Mike Berman <evencode@gmail.com>
Dmitriy Fedorenko <c0va23@gmail.com>
Zach Marcantel <zmarcantel@gmail.com>
James Maloney <jamessagan@gmail.com>
Ashwin Purohit <purohit@gmail.com>
Dan Kinder <dkinder.is.me@gmail.com>
Oliver Beattie <oliver@obeattie.com>
Justin Corpron <jncorpron@gmail.com>
Miles Delahunty <miles.delahunty@gmail.com>
Zach Badgett <zach.badgett@gmail.com>
Maciek Sakrejda <maciek@heroku.com>
Jeff Mitchell <jeffrey.mitchell@gmail.com>
Baptiste Fontaine <b@ptistefontaine.fr>
Matt Heath <matt@mattheath.com>
Jamie Cuthill <jamie.cuthill@gmail.com>
Adrian Casajus <adriancasajus@gmail.com>
John Weldon <johnweldon4@gmail.com>
Adrien Bustany <adrien@bustany.org>
Andrey Smirnov <smirnov.andrey@gmail.com>
Adam Weiner <adamsweiner@gmail.com>
Daniel Cannon <daniel@danielcannon.co.uk>
Johnny Bergström <johnny@joonix.se>
Adriano Orioli <orioli.adriano@gmail.com>
Claudiu Raveica <claudiu.raveica@gmail.com>
Artem Chernyshev <artem.0xD2@gmail.com>
Ference Fu <fym201@msn.com>
LOVOO <opensource@lovoo.com>
nikandfor <nikandfor@gmail.com>
Anthony Woods <awoods@raintank.io>
Alexander Inozemtsev <alexander.inozemtsev@gmail.com>
Rob McColl <rob@robmccoll.com>; <rmccoll@ionicsecurity.com>
Viktor Tönköl <viktor.toenkoel@motionlogic.de>
Ian Lozinski <ian.lozinski@gmail.com>
Michael Highstead <highstead@gmail.com>
Sarah Brown <esbie.is@gmail.com>
Caleb Doxsey <caleb@datadoghq.com>
Frederic Hemery <frederic.hemery@datadoghq.com>
Pekka Enberg <penberg@scylladb.com>
Mark M <m.mim95@gmail.com>
Bartosz Burclaf <burclaf@gmail.com>
Marcus King <marcusking01@gmail.com>
Andrew de Andrade <andrew@deandrade.com.br>
Robert Nix <robert@nicerobot.org>
Nathan Youngman <git@nathany.com>
Charles Law <charles.law@gmail.com>; <claw@conduce.com>
Nathan Davies <nathanjamesdavies@gmail.com>
Bo Blanton <bo.blanton@gmail.com>
Vincent Rischmann <me@vrischmann.me>
Jesse Claven <jesse.claven@gmail.com>
Derrick Wippler <thrawn01@gmail.com>
Leigh McCulloch <leigh@leighmcculloch.com>
Ron Kuris <swcafe@gmail.com>
Raphael Gavache <raphael.gavache@gmail.com>
Yasser Abdolmaleki <yasser@yasser.ca>
Krishnanand Thommandra <devtkrishna@gmail.com>
Blake Atkinson <me@blakeatkinson.com>
Dharmendra Parsaila <d4dharmu@gmail.com>
Nayef Ghattas <nayef.ghattas@datadoghq.com>
Michał Matczuk <mmatczuk@gmail.com>
Ben Krebsbach <ben.krebsbach@gmail.com>
Vivian Mathews <vivian.mathews.3@gmail.com>
Sascha Steinbiss <satta@debian.org>
Seth Rosenblum <seth.t.rosenblum@gmail.com>
Javier Zunzunegui <javier.zunzunegui.b@gmail.com>
Luke Hines <lukehines@protonmail.com>
Zhixin Wen <john.wenzhixin@hotmail.com>
Chang Liu <changliu.it@gmail.com>
Ingo Oeser <nightlyone@gmail.com>
Luke Hines <lukehines@protonmail.com>
Jacob Greenleaf <jacob@jacobgreenleaf.com>
Alex Lourie <alex@instaclustr.com>; <djay.il@gmail.com>
Marco Cadetg <cadetg@gmail.com>
Karl Matthias <karl@matthias.org>
Thomas Meson <zllak@hycik.org>
Martin Sucha <martin.sucha@kiwi.com>; <git@mm.ms47.eu>
Pavel Buchinchik <p.buchinchik@gmail.com>

@ -0,0 +1,78 @@
# Contributing to gocql
**TL;DR** - this manifesto sets out the bare minimum requirements for submitting a patch to gocql.
This guide outlines the process of landing patches in gocql and the general approach to maintaining the code base.
## Background
The goal of the gocql project is to provide a stable and robust CQL driver for Go. gocql is a community driven project that is coordinated by a small team of core developers.
## Minimum Requirement Checklist
The following is a check list of requirements that need to be satisfied in order for us to merge your patch:
* You should raise a pull request to gocql/gocql on Github
* The pull request has a title that clearly summarizes the purpose of the patch
* The motivation behind the patch is clearly defined in the pull request summary
* Your name and email have been added to the `AUTHORS` file (for copyright purposes)
* The patch will merge cleanly
* The test coverage does not fall below the critical threshold (currently 64%)
* The merge commit passes the regression test suite on Travis
* `go fmt` has been applied to the submitted code
* Functional changes (i.e. new features or changed behavior) are appropriately documented, either as a godoc or in the README (non-functional changes such as bug fixes may not require documentation)
If there are any requirements that can't be reasonably satisfied, please state this either on the pull request or as part of discussion on the mailing list. Where appropriate, the core team may apply discretion and make an exception to these requirements.
## Beyond The Checklist
In addition to stating the hard requirements, there are a bunch of things that we consider when assessing changes to the library. These soft requirements are helpful pointers of how to get a patch landed quicker and with less fuss.
### General QA Approach
The gocql team needs to consider the ongoing maintainability of the library at all times. Patches that look like they will introduce maintenance issues for the team will not be accepted.
Your patch will get merged quicker if you have decent test cases that provide test coverage for the new behavior you wish to introduce.
Unit tests are good, integration tests are even better. An example of a unit test is `marshal_test.go` - this tests the serialization code in isolation. `cassandra_test.go` is an integration test suite that is executed against every version of Cassandra that gocql supports as part of the CI process on Travis.
That said, the point of writing tests is to provide a safety net to catch regressions, so there is no need to go overboard with tests. Remember that the more tests you write, the more code we will have to maintain. So there's a balance to strike there.
### When It's Too Difficult To Automate Testing
There are legitimate examples of where it is infeasible to write a regression test for a change. Never fear, we will still consider the patch and quite possibly accept the change without a test. The gocql team takes a pragmatic approach to testing. At the end of the day, you could be addressing an issue that is too difficult to reproduce in a test suite, but still occurs in a real production app. In this case, your production app is the test case, and we will have to trust that your change is good.
Examples of pull requests that have been accepted without tests include:
* https://github.com/gocql/gocql/pull/181 - this patch would otherwise require a multi-node cluster to be booted as part of the CI build
* https://github.com/gocql/gocql/pull/179 - this bug can only be reproduced under heavy load in certain circumstances
### Sign Off Procedure
Generally speaking, a pull request can get merged by any one of the core gocql team. If your change is minor, chances are that one team member will just go ahead and merge it there and then. As stated earlier, suitable test coverage will increase the likelihood that a single reviewer will assess and merge your change. If your change has no test coverage, or looks like it may have wider implications for the health and stability of the library, the reviewer may elect to refer the change to another team member to achieve consensus before proceeding. Therefore, the tighter and cleaner your patch is, the quicker it will go through the review process.
### Supported Features
gocql is a low level wire driver for Cassandra CQL. By and large, we would like to keep the functional scope of the library as narrow as possible. We think that gocql should be tight and focused, and we will be naturally skeptical of things that could just as easily be implemented in a higher layer. Inevitably you will come across something that could be implemented in a higher layer, save for a minor change to the core API. In this instance, please strike up a conversation with the gocql team. Chances are we will understand what you are trying to achieve and will try to accommodate this in a maintainable way.
### Longer Term Evolution
There are some long term plans for gocql that have to be taken into account when assessing changes. That said, gocql is ultimately a community driven project and we don't have a massive development budget, so sometimes the long term view might need to be de-prioritized ahead of short term changes.
## Officially Supported Server Versions
Currently, the officially supported versions of the Cassandra server include:
* 1.2.18
* 2.0.9
Chances are that gocql will work with many other versions. If you would like us to support a particular version of Cassandra, please start a conversation about what version you'd like us to consider. We are more likely to accept a new version if you help out by extending the regression suite to cover the new version to be supported.
## The Core Dev Team
The core development team includes:
* tux21b
* phillipCouto
* Zariel
* 0x6e6562

@ -0,0 +1,27 @@
Copyright (c) 2016, The Gocql authors
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -0,0 +1,215 @@
gocql
=====
[![Join the chat at https://gitter.im/gocql/gocql](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/gocql/gocql?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
[![Build Status](https://travis-ci.org/gocql/gocql.svg?branch=master)](https://travis-ci.org/gocql/gocql)
[![GoDoc](https://godoc.org/github.com/gocql/gocql?status.svg)](https://godoc.org/github.com/gocql/gocql)
Package gocql implements a fast and robust Cassandra client for the
Go programming language.
Project Website: https://gocql.github.io/<br>
API documentation: https://godoc.org/github.com/gocql/gocql<br>
Discussions: https://groups.google.com/forum/#!forum/gocql
Supported Versions
------------------
The following matrix shows the versions of Go and Cassandra that are tested with the integration test suite as part of the CI build:
Go/Cassandra | 2.1.x | 2.2.x | 3.x.x
-------------| -------| ------| ---------
1.12 | yes | yes | yes
1.13 | yes | yes | yes
Gocql has been tested in production against many different versions of Cassandra. Due to limits in our CI setup we only test against the latest 3 major releases, which coincide with the official support from the Apache project.
Sunsetting Model
----------------
In general, the gocql team will focus on supporting the current and previous versions of Go. gocql may still work with older versions of Go, but official support for these versions will have been sunset.
Installation
------------
go get github.com/gocql/gocql
Features
--------
* Modern Cassandra client using the native transport
* Automatic type conversions between Cassandra and Go
* Support for all common types including sets, lists and maps
* Custom types can implement a `Marshaler` and `Unmarshaler` interface
* Strict type conversions without any loss of precision
* Built-In support for UUIDs (version 1 and 4)
* Support for logged, unlogged and counter batches
* Cluster management
* Automatic reconnect on connection failures with exponential falloff
* Round robin distribution of queries to different hosts
* Round robin distribution of queries to different connections on a host
* Each connection can execute up to n concurrent queries (whereby n is the limit set by the protocol version the client chooses to use)
* Optional automatic discovery of nodes
* Policy based connection pool with token aware and round-robin policy implementations
* Support for password authentication
* Iteration over paged results with configurable page size
* Support for TLS/SSL
* Optional frame compression (using snappy)
* Automatic query preparation
* Support for query tracing
* Support for Cassandra 2.1+ [binary protocol version 3](https://github.com/apache/cassandra/blob/trunk/doc/native_protocol_v3.spec)
* Support for up to 32768 streams
* Support for tuple types
* Support for client side timestamps by default
* Support for UDTs via a custom marshaller or struct tags
* Support for Cassandra 3.0+ [binary protocol version 4](https://github.com/apache/cassandra/blob/trunk/doc/native_protocol_v4.spec)
* An API to access the schema metadata of a given keyspace
Performance
-----------
While the driver strives to be highly performant, there are cases where it is difficult to test and verify. The driver is built
with maintainability and code readability in mind first and then performance and features, as such every now and then performance
may degrade, if this occurs please report and issue and it will be looked at and remedied. The only time the driver copies data from
its read buffer is when it Unmarshal's data into supplied types.
Some tips for getting more performance from the driver:
* Use the TokenAware policy
* Use many goroutines when doing inserts, the driver is asynchronous but provides a synchronous API, it can execute many queries concurrently
* Tune query page size
* Reading data from the network to unmarshal will incur a large amount of allocations, this can adversely affect the garbage collector, tune `GOGC`
* Close iterators after use to recycle byte buffers
Important Default Keyspace Changes
----------------------------------
gocql no longer supports executing "use <keyspace>" statements to simplify the library. The user still has the
ability to define the default keyspace for connections but now the keyspace can only be defined before a
session is created. Queries can still access keyspaces by indicating the keyspace in the query:
```sql
SELECT * FROM example2.table;
```
Example of correct usage:
```go
cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
cluster.Keyspace = "example"
...
session, err := cluster.CreateSession()
```
Example of incorrect usage:
```go
cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
cluster.Keyspace = "example"
...
session, err := cluster.CreateSession()
if err = session.Query("use example2").Exec(); err != nil {
log.Fatal(err)
}
```
This will result in an err being returned from the session.Query line as the user is trying to execute a "use"
statement.
Example
-------
```go
/* Before you execute the program, Launch `cqlsh` and execute:
create keyspace example with replication = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };
create table example.tweet(timeline text, id UUID, text text, PRIMARY KEY(id));
create index on example.tweet(timeline);
*/
package main
import (
"fmt"
"log"
"github.com/gocql/gocql"
)
func main() {
// connect to the cluster
cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
cluster.Keyspace = "example"
cluster.Consistency = gocql.Quorum
session, _ := cluster.CreateSession()
defer session.Close()
// insert a tweet
if err := session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`,
"me", gocql.TimeUUID(), "hello world").Exec(); err != nil {
log.Fatal(err)
}
var id gocql.UUID
var text string
/* Search for a specific set of records whose 'timeline' column matches
* the value 'me'. The secondary index that we created earlier will be
* used for optimizing the search */
if err := session.Query(`SELECT id, text FROM tweet WHERE timeline = ? LIMIT 1`,
"me").Consistency(gocql.One).Scan(&id, &text); err != nil {
log.Fatal(err)
}
fmt.Println("Tweet:", id, text)
// list all tweets
iter := session.Query(`SELECT id, text FROM tweet WHERE timeline = ?`, "me").Iter()
for iter.Scan(&id, &text) {
fmt.Println("Tweet:", id, text)
}
if err := iter.Close(); err != nil {
log.Fatal(err)
}
}
```
Data Binding
------------
There are various ways to bind application level data structures to CQL statements:
* You can write the data binding by hand, as outlined in the Tweet example. This provides you with the greatest flexibility, but it does mean that you need to keep your application code in sync with your Cassandra schema.
* You can dynamically marshal an entire query result into an `[]map[string]interface{}` using the `SliceMap()` API. This returns a slice of row maps keyed by CQL column names. This method requires no special interaction with the gocql API, but it does require your application to be able to deal with a key value view of your data.
* As a refinement on the `SliceMap()` API you can also call `MapScan()` which returns `map[string]interface{}` instances in a row by row fashion.
* The `Bind()` API provides a client app with a low level mechanism to introspect query meta data and extract appropriate field values from application level data structures.
* The [gocqlx](https://github.com/scylladb/gocqlx) package is an idiomatic extension to gocql that provides usability features. With gocqlx you can bind the query parameters from maps and structs, use named query parameters (:identifier) and scan the query results into structs and slices. It comes with a fluent and flexible CQL query builder that supports full CQL spec, including BATCH statements and custom functions.
* Building on top of the gocql driver, [cqlr](https://github.com/relops/cqlr) adds the ability to auto-bind a CQL iterator to a struct or to bind a struct to an INSERT statement.
* Another external project that layers on top of gocql is [cqlc](http://relops.com/cqlc) which generates gocql compliant code from your Cassandra schema so that you can write type safe CQL statements in Go with a natural query syntax.
* [gocassa](https://github.com/hailocab/gocassa) is an external project that layers on top of gocql to provide convenient query building and data binding.
* [gocqltable](https://github.com/kristoiv/gocqltable) provides an ORM-style convenience layer to make CRUD operations with gocql easier.
Ecosystem
---------
The following community maintained tools are known to integrate with gocql:
* [gocqlx](https://github.com/scylladb/gocqlx) is a gocql extension that automates data binding, adds named queries support, provides flexible query builders and plays well with gocql.
* [journey](https://github.com/db-journey/journey) is a migration tool with Cassandra support.
* [negronicql](https://github.com/mikebthun/negronicql) is gocql middleware for Negroni.
* [cqlr](https://github.com/relops/cqlr) adds the ability to auto-bind a CQL iterator to a struct or to bind a struct to an INSERT statement.
* [cqlc](http://relops.com/cqlc) generates gocql compliant code from your Cassandra schema so that you can write type safe CQL statements in Go with a natural query syntax.
* [gocassa](https://github.com/hailocab/gocassa) provides query building, adds data binding, and provides easy-to-use "recipe" tables for common query use-cases.
* [gocqltable](https://github.com/kristoiv/gocqltable) is a wrapper around gocql that aims to simplify common operations.
* [gockle](https://github.com/willfaught/gockle) provides simple, mockable interfaces that wrap gocql types
* [scylladb](https://github.com/scylladb/scylla) is a fast Apache Cassandra-compatible NoSQL database
* [go-cql-driver](https://github.com/MichaelS11/go-cql-driver) is an CQL driver conforming to the built-in database/sql interface. It is good for simple use cases where the database/sql interface is wanted. The CQL driver is a wrapper around this project.
Other Projects
--------------
* [gocqldriver](https://github.com/tux21b/gocqldriver) is the predecessor of gocql based on Go's `database/sql` package. This project isn't maintained anymore, because Cassandra wasn't a good fit for the traditional `database/sql` API. Use this package instead.
SEO
---
For some reason, when you Google `golang cassandra`, this project doesn't feature very highly in the result list. But if you Google `go cassandra`, then we're a bit higher up the list. So this is note to try to convince Google that golang is an alias for Go.
License
-------
> Copyright (c) 2012-2016 The gocql Authors. All rights reserved.
> Use of this source code is governed by a BSD-style
> license that can be found in the LICENSE file.

@ -0,0 +1,26 @@
package gocql
import "net"
// AddressTranslator provides a way to translate node addresses (and ports) that are
// discovered or received as a node event. This can be useful in an ec2 environment,
// for instance, to translate public IPs to private IPs.
type AddressTranslator interface {
// Translate will translate the provided address and/or port to another
// address and/or port. If no translation is possible, Translate will return the
// address and port provided to it.
Translate(addr net.IP, port int) (net.IP, int)
}
type AddressTranslatorFunc func(addr net.IP, port int) (net.IP, int)
func (fn AddressTranslatorFunc) Translate(addr net.IP, port int) (net.IP, int) {
return fn(addr, port)
}
// IdentityTranslator will do nothing but return what it was provided. It is essentially a no-op.
func IdentityTranslator() AddressTranslator {
return AddressTranslatorFunc(func(addr net.IP, port int) (net.IP, int) {
return addr, port
})
}

@ -0,0 +1,211 @@
// Copyright (c) 2012 The gocql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gocql
import (
"errors"
"net"
"time"
)
// PoolConfig configures the connection pool used by the driver, it defaults to
// using a round-robin host selection policy and a round-robin connection selection
// policy for each host.
type PoolConfig struct {
// HostSelectionPolicy sets the policy for selecting which host to use for a
// given query (default: RoundRobinHostPolicy())
HostSelectionPolicy HostSelectionPolicy
}
func (p PoolConfig) buildPool(session *Session) *policyConnPool {
return newPolicyConnPool(session)
}
// ClusterConfig is a struct to configure the default cluster implementation
// of gocql. It has a variety of attributes that can be used to modify the
// behavior to fit the most common use cases. Applications that require a
// different setup must implement their own cluster.
type ClusterConfig struct {
// addresses for the initial connections. It is recommended to use the value set in
// the Cassandra config for broadcast_address or listen_address, an IP address not
// a domain name. This is because events from Cassandra will use the configured IP
// address, which is used to index connected hosts. If the domain name specified
// resolves to more than 1 IP address then the driver may connect multiple times to
// the same host, and will not mark the node being down or up from events.
Hosts []string
CQLVersion string // CQL version (default: 3.0.0)
// ProtoVersion sets the version of the native protocol to use, this will
// enable features in the driver for specific protocol versions, generally this
// should be set to a known version (2,3,4) for the cluster being connected to.
//
// If it is 0 or unset (the default) then the driver will attempt to discover the
// highest supported protocol for the cluster. In clusters with nodes of different
// versions the protocol selected is not defined (ie, it can be any of the supported in the cluster)
ProtoVersion int
Timeout time.Duration // connection timeout (default: 600ms)
ConnectTimeout time.Duration // initial connection timeout, used during initial dial to server (default: 600ms)
Port int // port (default: 9042)
Keyspace string // initial keyspace (optional)
NumConns int // number of connections per host (default: 2)
Consistency Consistency // default consistency level (default: Quorum)
Compressor Compressor // compression algorithm (default: nil)
Authenticator Authenticator // authenticator (default: nil)
AuthProvider func(h *HostInfo) (Authenticator, error) // an authenticator factory. Can be used to create alternative authenticators (default: nil)
RetryPolicy RetryPolicy // Default retry policy to use for queries (default: 0)
ConvictionPolicy ConvictionPolicy // Decide whether to mark host as down based on the error and host info (default: SimpleConvictionPolicy)
ReconnectionPolicy ReconnectionPolicy // Default reconnection policy to use for reconnecting before trying to mark host as down (default: see below)
SocketKeepalive time.Duration // The keepalive period to use, enabled if > 0 (default: 0)
MaxPreparedStmts int // Sets the maximum cache size for prepared statements globally for gocql (default: 1000)
MaxRoutingKeyInfo int // Sets the maximum cache size for query info about statements for each session (default: 1000)
PageSize int // Default page size to use for created sessions (default: 5000)
SerialConsistency SerialConsistency // Sets the consistency for the serial part of queries, values can be either SERIAL or LOCAL_SERIAL (default: unset)
SslOpts *SslOptions
DefaultTimestamp bool // Sends a client side timestamp for all requests which overrides the timestamp at which it arrives at the server. (default: true, only enabled for protocol 3 and above)
// PoolConfig configures the underlying connection pool, allowing the
// configuration of host selection and connection selection policies.
PoolConfig PoolConfig
// If not zero, gocql attempt to reconnect known DOWN nodes in every ReconnectInterval.
ReconnectInterval time.Duration
// The maximum amount of time to wait for schema agreement in a cluster after
// receiving a schema change frame. (default: 60s)
MaxWaitSchemaAgreement time.Duration
// HostFilter will filter all incoming events for host, any which don't pass
// the filter will be ignored. If set will take precedence over any options set
// via Discovery
HostFilter HostFilter
// AddressTranslator will translate addresses found on peer discovery and/or
// node change events.
AddressTranslator AddressTranslator
// If IgnorePeerAddr is true and the address in system.peers does not match
// the supplied host by either initial hosts or discovered via events then the
// host will be replaced with the supplied address.
//
// For example if an event comes in with host=10.0.0.1 but when looking up that
// address in system.local or system.peers returns 127.0.0.1, the peer will be
// set to 10.0.0.1 which is what will be used to connect to.
IgnorePeerAddr bool
// If DisableInitialHostLookup then the driver will not attempt to get host info
// from the system.peers table, this will mean that the driver will connect to
// hosts supplied and will not attempt to lookup the hosts information, this will
// mean that data_centre, rack and token information will not be available and as
// such host filtering and token aware query routing will not be available.
DisableInitialHostLookup bool
// Configure events the driver will register for
Events struct {
// disable registering for status events (node up/down)
DisableNodeStatusEvents bool
// disable registering for topology events (node added/removed/moved)
DisableTopologyEvents bool
// disable registering for schema events (keyspace/table/function removed/created/updated)
DisableSchemaEvents bool
}
// DisableSkipMetadata will override the internal result metadata cache so that the driver does not
// send skip_metadata for queries, this means that the result will always contain
// the metadata to parse the rows and will not reuse the metadata from the prepared
// statement.
//
// See https://issues.apache.org/jira/browse/CASSANDRA-10786
DisableSkipMetadata bool
// QueryObserver will set the provided query observer on all queries created from this session.
// Use it to collect metrics / stats from queries by providing an implementation of QueryObserver.
QueryObserver QueryObserver
// BatchObserver will set the provided batch observer on all queries created from this session.
// Use it to collect metrics / stats from batch queries by providing an implementation of BatchObserver.
BatchObserver BatchObserver
// ConnectObserver will set the provided connect observer on all queries
// created from this session.
ConnectObserver ConnectObserver
// FrameHeaderObserver will set the provided frame header observer on all frames' headers created from this session.
// Use it to collect metrics / stats from frames by providing an implementation of FrameHeaderObserver.
FrameHeaderObserver FrameHeaderObserver
// Default idempotence for queries
DefaultIdempotence bool
// The time to wait for frames before flushing the frames connection to Cassandra.
// Can help reduce syscall overhead by making less calls to write. Set to 0 to
// disable.
//
// (default: 200 microseconds)
WriteCoalesceWaitTime time.Duration
// internal config for testing
disableControlConn bool
}
// NewCluster generates a new config for the default cluster implementation.
//
// The supplied hosts are used to initially connect to the cluster then the rest of
// the ring will be automatically discovered. It is recommended to use the value set in
// the Cassandra config for broadcast_address or listen_address, an IP address not
// a domain name. This is because events from Cassandra will use the configured IP
// address, which is used to index connected hosts. If the domain name specified
// resolves to more than 1 IP address then the driver may connect multiple times to
// the same host, and will not mark the node being down or up from events.
func NewCluster(hosts ...string) *ClusterConfig {
cfg := &ClusterConfig{
Hosts: hosts,
CQLVersion: "3.0.0",
Timeout: 600 * time.Millisecond,
ConnectTimeout: 600 * time.Millisecond,
Port: 9042,
NumConns: 2,
Consistency: Quorum,
MaxPreparedStmts: defaultMaxPreparedStmts,
MaxRoutingKeyInfo: 1000,
PageSize: 5000,
DefaultTimestamp: true,
MaxWaitSchemaAgreement: 60 * time.Second,
ReconnectInterval: 60 * time.Second,
ConvictionPolicy: &SimpleConvictionPolicy{},
ReconnectionPolicy: &ConstantReconnectionPolicy{MaxRetries: 3, Interval: 1 * time.Second},
WriteCoalesceWaitTime: 200 * time.Microsecond,
}
return cfg
}
// CreateSession initializes the cluster based on this config and returns a
// session object that can be used to interact with the database.
func (cfg *ClusterConfig) CreateSession() (*Session, error) {
return NewSession(*cfg)
}
// translateAddressPort is a helper method that will use the given AddressTranslator
// if defined, to translate the given address and port into a possibly new address
// and port, If no AddressTranslator or if an error occurs, the given address and
// port will be returned.
func (cfg *ClusterConfig) translateAddressPort(addr net.IP, port int) (net.IP, int) {
if cfg.AddressTranslator == nil || len(addr) == 0 {
return addr, port
}
newAddr, newPort := cfg.AddressTranslator.Translate(addr, port)
if gocqlDebug {
Logger.Printf("gocql: translating address '%v:%d' to '%v:%d'", addr, port, newAddr, newPort)
}
return newAddr, newPort
}
func (cfg *ClusterConfig) filterHost(host *HostInfo) bool {
return !(cfg.HostFilter == nil || cfg.HostFilter.Accept(host))
}
var (
ErrNoHosts = errors.New("no hosts provided")
ErrNoConnectionsStarted = errors.New("no connections were made when creating the session")
ErrHostQueryFailed = errors.New("unable to populate Hosts")
)

@ -0,0 +1,28 @@
package gocql
import (
"github.com/golang/snappy"
)
type Compressor interface {
Name() string
Encode(data []byte) ([]byte, error)
Decode(data []byte) ([]byte, error)
}
// SnappyCompressor implements the Compressor interface and can be used to
// compress incoming and outgoing frames. The snappy compression algorithm
// aims for very high speeds and reasonable compression.
type SnappyCompressor struct{}
func (s SnappyCompressor) Name() string {
return "snappy"
}
func (s SnappyCompressor) Encode(data []byte) ([]byte, error) {
return snappy.Encode(nil, data), nil
}
func (s SnappyCompressor) Decode(data []byte) ([]byte, error) {
return snappy.Decode(nil, data)
}

File diff suppressed because it is too large Load Diff

@ -0,0 +1,581 @@
// Copyright (c) 2012 The gocql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gocql
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"math/rand"
"net"
"sync"
"sync/atomic"
"time"
)
// interface to implement to receive the host information
type SetHosts interface {
SetHosts(hosts []*HostInfo)
}
// interface to implement to receive the partitioner value
type SetPartitioner interface {
SetPartitioner(partitioner string)
}
func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
if sslOpts.Config == nil {
sslOpts.Config = &tls.Config{}
}
// ca cert is optional
if sslOpts.CaPath != "" {
if sslOpts.RootCAs == nil {
sslOpts.RootCAs = x509.NewCertPool()
}
pem, err := ioutil.ReadFile(sslOpts.CaPath)
if err != nil {
return nil, fmt.Errorf("connectionpool: unable to open CA certs: %v", err)
}
if !sslOpts.RootCAs.AppendCertsFromPEM(pem) {
return nil, errors.New("connectionpool: failed parsing or CA certs")
}
}
if sslOpts.CertPath != "" || sslOpts.KeyPath != "" {
mycert, err := tls.LoadX509KeyPair(sslOpts.CertPath, sslOpts.KeyPath)
if err != nil {
return nil, fmt.Errorf("connectionpool: unable to load X509 key pair: %v", err)
}
sslOpts.Certificates = append(sslOpts.Certificates, mycert)
}
sslOpts.InsecureSkipVerify = !sslOpts.EnableHostVerification
// return clone to avoid race
return sslOpts.Config.Clone(), nil
}
type policyConnPool struct {
session *Session
port int
numConns int
keyspace string
mu sync.RWMutex
hostConnPools map[string]*hostConnPool
endpoints []string
}
func connConfig(cfg *ClusterConfig) (*ConnConfig, error) {
var (
err error
tlsConfig *tls.Config
)
// TODO(zariel): move tls config setup into session init.
if cfg.SslOpts != nil {
tlsConfig, err = setupTLSConfig(cfg.SslOpts)
if err != nil {
return nil, err
}
}
return &ConnConfig{
ProtoVersion: cfg.ProtoVersion,
CQLVersion: cfg.CQLVersion,
Timeout: cfg.Timeout,
ConnectTimeout: cfg.ConnectTimeout,
Compressor: cfg.Compressor,
Authenticator: cfg.Authenticator,
AuthProvider: cfg.AuthProvider,
Keepalive: cfg.SocketKeepalive,
tlsConfig: tlsConfig,
disableCoalesce: tlsConfig != nil, // write coalescing doesn't work with framing on top of TCP like in TLS.
}, nil
}
func newPolicyConnPool(session *Session) *policyConnPool {
// create the pool
pool := &policyConnPool{
session: session,
port: session.cfg.Port,
numConns: session.cfg.NumConns,
keyspace: session.cfg.Keyspace,
hostConnPools: map[string]*hostConnPool{},
}
pool.endpoints = make([]string, len(session.cfg.Hosts))
copy(pool.endpoints, session.cfg.Hosts)
return pool
}
func (p *policyConnPool) SetHosts(hosts []*HostInfo) {
p.mu.Lock()
defer p.mu.Unlock()
toRemove := make(map[string]struct{})
for addr := range p.hostConnPools {
toRemove[addr] = struct{}{}
}
pools := make(chan *hostConnPool)
createCount := 0
for _, host := range hosts {
if !host.IsUp() {
// don't create a connection pool for a down host
continue
}
ip := host.ConnectAddress().String()
if _, exists := p.hostConnPools[ip]; exists {
// still have this host, so don't remove it
delete(toRemove, ip)
continue
}
createCount++
go func(host *HostInfo) {
// create a connection pool for the host
pools <- newHostConnPool(
p.session,
host,
p.port,
p.numConns,
p.keyspace,
)
}(host)
}
// add created pools
for createCount > 0 {
pool := <-pools
createCount--
if pool.Size() > 0 {
// add pool only if there a connections available
p.hostConnPools[string(pool.host.ConnectAddress())] = pool
}
}
for addr := range toRemove {
pool := p.hostConnPools[addr]
delete(p.hostConnPools, addr)
go pool.Close()
}
}
func (p *policyConnPool) Size() int {
p.mu.RLock()
count := 0
for _, pool := range p.hostConnPools {
count += pool.Size()
}
p.mu.RUnlock()
return count
}
func (p *policyConnPool) getPool(host *HostInfo) (pool *hostConnPool, ok bool) {
ip := host.ConnectAddress().String()
p.mu.RLock()
pool, ok = p.hostConnPools[ip]
p.mu.RUnlock()
return
}
func (p *policyConnPool) Close() {
p.mu.Lock()
defer p.mu.Unlock()
// close the pools
for addr, pool := range p.hostConnPools {
delete(p.hostConnPools, addr)
pool.Close()
}
}
func (p *policyConnPool) addHost(host *HostInfo) {
ip := host.ConnectAddress().String()
p.mu.Lock()
pool, ok := p.hostConnPools[ip]
if !ok {
pool = newHostConnPool(
p.session,
host,
host.Port(), // TODO: if port == 0 use pool.port?
p.numConns,
p.keyspace,
)
p.hostConnPools[ip] = pool
}
p.mu.Unlock()
pool.fill()
}
func (p *policyConnPool) removeHost(ip net.IP) {
k := ip.String()
p.mu.Lock()
pool, ok := p.hostConnPools[k]
if !ok {
p.mu.Unlock()
return
}
delete(p.hostConnPools, k)
p.mu.Unlock()
go pool.Close()
}
func (p *policyConnPool) hostUp(host *HostInfo) {
// TODO(zariel): have a set of up hosts and down hosts, we can internally
// detect down hosts, then try to reconnect to them.
p.addHost(host)
}
func (p *policyConnPool) hostDown(ip net.IP) {
// TODO(zariel): mark host as down so we can try to connect to it later, for
// now just treat it has removed.
p.removeHost(ip)
}
// hostConnPool is a connection pool for a single host.
// Connection selection is based on a provided ConnSelectionPolicy
type hostConnPool struct {
session *Session
host *HostInfo
port int
addr string
size int
keyspace string
// protection for conns, closed, filling
mu sync.RWMutex
conns []*Conn
closed bool
filling bool
pos uint32
}
func (h *hostConnPool) String() string {
h.mu.RLock()
defer h.mu.RUnlock()
return fmt.Sprintf("[filling=%v closed=%v conns=%v size=%v host=%v]",
h.filling, h.closed, len(h.conns), h.size, h.host)
}
func newHostConnPool(session *Session, host *HostInfo, port, size int,
keyspace string) *hostConnPool {
pool := &hostConnPool{
session: session,
host: host,
port: port,
addr: (&net.TCPAddr{IP: host.ConnectAddress(), Port: host.Port()}).String(),
size: size,
keyspace: keyspace,
conns: make([]*Conn, 0, size),
filling: false,
closed: false,
}
// the pool is not filled or connected
return pool
}
// Pick a connection from this connection pool for the given query.
func (pool *hostConnPool) Pick() *Conn {
pool.mu.RLock()
defer pool.mu.RUnlock()
if pool.closed {
return nil
}
size := len(pool.conns)
if size < pool.size {
// try to fill the pool
go pool.fill()
if size == 0 {
return nil
}
}
pos := int(atomic.AddUint32(&pool.pos, 1) - 1)
var (
leastBusyConn *Conn
streamsAvailable int
)
// find the conn which has the most available streams, this is racy
for i := 0; i < size; i++ {
conn := pool.conns[(pos+i)%size]
if streams := conn.AvailableStreams(); streams > streamsAvailable {
leastBusyConn = conn
streamsAvailable = streams
}
}
return leastBusyConn
}
//Size returns the number of connections currently active in the pool
func (pool *hostConnPool) Size() int {
pool.mu.RLock()
defer pool.mu.RUnlock()
return len(pool.conns)
}
//Close the connection pool
func (pool *hostConnPool) Close() {
pool.mu.Lock()
if pool.closed {
pool.mu.Unlock()
return
}
pool.closed = true
// ensure we dont try to reacquire the lock in handleError
// TODO: improve this as the following can happen
// 1) we have locked pool.mu write lock
// 2) conn.Close calls conn.closeWithError(nil)
// 3) conn.closeWithError calls conn.Close() which returns an error
// 4) conn.closeWithError calls pool.HandleError with the error from conn.Close
// 5) pool.HandleError tries to lock pool.mu
// deadlock
// empty the pool
conns := pool.conns
pool.conns = nil
pool.mu.Unlock()
// close the connections
for _, conn := range conns {
conn.Close()
}
}
// Fill the connection pool
func (pool *hostConnPool) fill() {
pool.mu.RLock()
// avoid filling a closed pool, or concurrent filling
if pool.closed || pool.filling {
pool.mu.RUnlock()
return
}
// determine the filling work to be done
startCount := len(pool.conns)
fillCount := pool.size - startCount
// avoid filling a full (or overfull) pool
if fillCount <= 0 {
pool.mu.RUnlock()
return
}
// switch from read to write lock
pool.mu.RUnlock()
pool.mu.Lock()
// double check everything since the lock was released
startCount = len(pool.conns)
fillCount = pool.size - startCount
if pool.closed || pool.filling || fillCount <= 0 {
// looks like another goroutine already beat this
// goroutine to the filling
pool.mu.Unlock()
return
}
// ok fill the pool
pool.filling = true
// allow others to access the pool while filling
pool.mu.Unlock()
// only this goroutine should make calls to fill/empty the pool at this
// point until after this routine or its subordinates calls
// fillingStopped
// fill only the first connection synchronously
if startCount == 0 {
err := pool.connect()
pool.logConnectErr(err)
if err != nil {
// probably unreachable host
pool.fillingStopped(true)
// this is call with the connection pool mutex held, this call will
// then recursively try to lock it again. FIXME
if pool.session.cfg.ConvictionPolicy.AddFailure(err, pool.host) {
go pool.session.handleNodeDown(pool.host.ConnectAddress(), pool.port)
}
return
}
// filled one
fillCount--
}
// fill the rest of the pool asynchronously
go func() {
err := pool.connectMany(fillCount)
// mark the end of filling
pool.fillingStopped(err != nil)
}()
}
func (pool *hostConnPool) logConnectErr(err error) {
if opErr, ok := err.(*net.OpError); ok && (opErr.Op == "dial" || opErr.Op == "read") {
// connection refused
// these are typical during a node outage so avoid log spam.
if gocqlDebug {
Logger.Printf("unable to dial %q: %v\n", pool.host.ConnectAddress(), err)
}
} else if err != nil {
// unexpected error
Logger.Printf("error: failed to connect to %s due to error: %v", pool.addr, err)
}
}
// transition back to a not-filling state.
func (pool *hostConnPool) fillingStopped(hadError bool) {
if hadError {
// wait for some time to avoid back-to-back filling
// this provides some time between failed attempts
// to fill the pool for the host to recover
time.Sleep(time.Duration(rand.Int31n(100)+31) * time.Millisecond)
}
pool.mu.Lock()
pool.filling = false
pool.mu.Unlock()
}
// connectMany creates new connections concurrent.
func (pool *hostConnPool) connectMany(count int) error {
if count == 0 {
return nil
}
var (
wg sync.WaitGroup
mu sync.Mutex
connectErr error
)
wg.Add(count)
for i := 0; i < count; i++ {
go func() {
defer wg.Done()
err := pool.connect()
pool.logConnectErr(err)
if err != nil {
mu.Lock()
connectErr = err
mu.Unlock()
}
}()
}
// wait for all connections are done
wg.Wait()
return connectErr
}
// create a new connection to the host and add it to the pool
func (pool *hostConnPool) connect() (err error) {
// TODO: provide a more robust connection retry mechanism, we should also
// be able to detect hosts that come up by trying to connect to downed ones.
// try to connect
var conn *Conn
reconnectionPolicy := pool.session.cfg.ReconnectionPolicy
for i := 0; i < reconnectionPolicy.GetMaxRetries(); i++ {
conn, err = pool.session.connect(pool.session.ctx, pool.host, pool)
if err == nil {
break
}
if opErr, isOpErr := err.(*net.OpError); isOpErr {
// if the error is not a temporary error (ex: network unreachable) don't
// retry
if !opErr.Temporary() {
break
}
}
if gocqlDebug {
Logger.Printf("connection failed %q: %v, reconnecting with %T\n",
pool.host.ConnectAddress(), err, reconnectionPolicy)
}
time.Sleep(reconnectionPolicy.GetInterval(i))
}
if err != nil {
return err
}
if pool.keyspace != "" {
// set the keyspace
if err = conn.UseKeyspace(pool.keyspace); err != nil {
conn.Close()
return err
}
}
// add the Conn to the pool
pool.mu.Lock()
defer pool.mu.Unlock()
if pool.closed {
conn.Close()
return nil
}
pool.conns = append(pool.conns, conn)
return nil
}
// handle any error from a Conn
func (pool *hostConnPool) HandleError(conn *Conn, err error, closed bool) {
if !closed {
// still an open connection, so continue using it
return
}
// TODO: track the number of errors per host and detect when a host is dead,
// then also have something which can detect when a host comes back.
pool.mu.Lock()
defer pool.mu.Unlock()
if pool.closed {
// pool closed
return
}
// find the connection index
for i, candidate := range pool.conns {
if candidate == conn {
// remove the connection, not preserving order
pool.conns[i], pool.conns = pool.conns[len(pool.conns)-1], pool.conns[:len(pool.conns)-1]
// lost a connection, so fill the pool
go pool.fill()
break
}
}
}

@ -0,0 +1,488 @@
package gocql
import (
"context"
crand "crypto/rand"
"errors"
"fmt"
"math/rand"
"net"
"os"
"regexp"
"strconv"
"sync"
"sync/atomic"
"time"
)
var (
randr *rand.Rand
mutRandr sync.Mutex
)
func init() {
b := make([]byte, 4)
if _, err := crand.Read(b); err != nil {
panic(fmt.Sprintf("unable to seed random number generator: %v", err))
}
randr = rand.New(rand.NewSource(int64(readInt(b))))
}
// Ensure that the atomic variable is aligned to a 64bit boundary
// so that atomic operations can be applied on 32bit architectures.
type controlConn struct {
started int32
reconnecting int32
session *Session
conn atomic.Value
retry RetryPolicy
quit chan struct{}
}
func createControlConn(session *Session) *controlConn {
control := &controlConn{
session: session,
quit: make(chan struct{}),
retry: &SimpleRetryPolicy{NumRetries: 3},
}
control.conn.Store((*connHost)(nil))
return control
}
func (c *controlConn) heartBeat() {
if !atomic.CompareAndSwapInt32(&c.started, 0, 1) {
return
}
sleepTime := 1 * time.Second
timer := time.NewTimer(sleepTime)
defer timer.Stop()
for {
timer.Reset(sleepTime)
select {
case <-c.quit:
return
case <-timer.C:
}
resp, err := c.writeFrame(&writeOptionsFrame{})
if err != nil {
goto reconn
}
switch resp.(type) {
case *supportedFrame:
// Everything ok
sleepTime = 5 * time.Second
continue
case error:
goto reconn
default:
panic(fmt.Sprintf("gocql: unknown frame in response to options: %T", resp))
}
reconn:
// try to connect a bit faster
sleepTime = 1 * time.Second
c.reconnect(true)
continue
}
}
var hostLookupPreferV4 = os.Getenv("GOCQL_HOST_LOOKUP_PREFER_V4") == "true"
func hostInfo(addr string, defaultPort int) ([]*HostInfo, error) {
var port int
host, portStr, err := net.SplitHostPort(addr)
if err != nil {
host = addr
port = defaultPort
} else {
port, err = strconv.Atoi(portStr)
if err != nil {
return nil, err
}
}
var hosts []*HostInfo
// Check if host is a literal IP address
if ip := net.ParseIP(host); ip != nil {
hosts = append(hosts, &HostInfo{hostname: host, connectAddress: ip, port: port})
return hosts, nil
}
// Look up host in DNS
ips, err := LookupIP(host)
if err != nil {
return nil, err
} else if len(ips) == 0 {
return nil, fmt.Errorf("No IP's returned from DNS lookup for %q", addr)
}
// Filter to v4 addresses if any present
if hostLookupPreferV4 {
var preferredIPs []net.IP
for _, v := range ips {
if v4 := v.To4(); v4 != nil {
preferredIPs = append(preferredIPs, v4)
}
}
if len(preferredIPs) != 0 {
ips = preferredIPs
}
}
for _, ip := range ips {
hosts = append(hosts, &HostInfo{hostname: host, connectAddress: ip, port: port})
}
return hosts, nil
}
func shuffleHosts(hosts []*HostInfo) []*HostInfo {
shuffled := make([]*HostInfo, len(hosts))
copy(shuffled, hosts)
mutRandr.Lock()
randr.Shuffle(len(hosts), func(i, j int) {
shuffled[i], shuffled[j] = shuffled[j], shuffled[i]
})
mutRandr.Unlock()
return shuffled
}
func (c *controlConn) shuffleDial(endpoints []*HostInfo) (*Conn, error) {
// shuffle endpoints so not all drivers will connect to the same initial
// node.
shuffled := shuffleHosts(endpoints)
cfg := *c.session.connCfg
cfg.disableCoalesce = true
var err error
for _, host := range shuffled {
var conn *Conn
conn, err = c.session.dial(c.session.ctx, host, &cfg, c)
if err == nil {
return conn, nil
}
Logger.Printf("gocql: unable to dial control conn %v: %v\n", host.ConnectAddress(), err)
}
return nil, err
}
// this is going to be version dependant and a nightmare to maintain :(
var protocolSupportRe = regexp.MustCompile(`the lowest supported version is \d+ and the greatest is (\d+)$`)
func parseProtocolFromError(err error) int {
// I really wish this had the actual info in the error frame...
matches := protocolSupportRe.FindAllStringSubmatch(err.Error(), -1)
if len(matches) != 1 || len(matches[0]) != 2 {
if verr, ok := err.(*protocolError); ok {
return int(verr.frame.Header().version.version())
}
return 0
}
max, err := strconv.Atoi(matches[0][1])
if err != nil {
return 0
}
return max
}
func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) {
hosts = shuffleHosts(hosts)
connCfg := *c.session.connCfg
connCfg.ProtoVersion = 4 // TODO: define maxProtocol
handler := connErrorHandlerFn(func(c *Conn, err error, closed bool) {
// we should never get here, but if we do it means we connected to a
// host successfully which means our attempted protocol version worked
if !closed {
c.Close()
}
})
var err error
for _, host := range hosts {
var conn *Conn
conn, err = c.session.dial(c.session.ctx, host, &connCfg, handler)
if conn != nil {
conn.Close()
}
if err == nil {
return connCfg.ProtoVersion, nil
}
if proto := parseProtocolFromError(err); proto > 0 {
return proto, nil
}
}
return 0, err
}
func (c *controlConn) connect(hosts []*HostInfo) error {
if len(hosts) == 0 {
return errors.New("control: no endpoints specified")
}
conn, err := c.shuffleDial(hosts)
if err != nil {
return fmt.Errorf("control: unable to connect to initial hosts: %v", err)
}
if err := c.setupConn(conn); err != nil {
conn.Close()
return fmt.Errorf("control: unable to setup connection: %v", err)
}
// we could fetch the initial ring here and update initial host data. So that
// when we return from here we have a ring topology ready to go.
go c.heartBeat()
return nil
}
type connHost struct {
conn *Conn
host *HostInfo
}
func (c *controlConn) setupConn(conn *Conn) error {
if err := c.registerEvents(conn); err != nil {
conn.Close()
return err
}
// TODO(zariel): do we need to fetch host info everytime
// the control conn connects? Surely we have it cached?
host, err := conn.localHostInfo(context.TODO())
if err != nil {
return err
}
ch := &connHost{
conn: conn,
host: host,
}
c.conn.Store(ch)
c.session.handleNodeUp(host.ConnectAddress(), host.Port(), false)
return nil
}
func (c *controlConn) registerEvents(conn *Conn) error {
var events []string
if !c.session.cfg.Events.DisableTopologyEvents {
events = append(events, "TOPOLOGY_CHANGE")
}
if !c.session.cfg.Events.DisableNodeStatusEvents {
events = append(events, "STATUS_CHANGE")
}
if !c.session.cfg.Events.DisableSchemaEvents {
events = append(events, "SCHEMA_CHANGE")
}
if len(events) == 0 {
return nil
}
framer, err := conn.exec(context.Background(),
&writeRegisterFrame{
events: events,
}, nil)
if err != nil {
return err
}
frame, err := framer.parseFrame()
if err != nil {
return err
} else if _, ok := frame.(*readyFrame); !ok {
return fmt.Errorf("unexpected frame in response to register: got %T: %v\n", frame, frame)
}
return nil
}
func (c *controlConn) reconnect(refreshring bool) {
if !atomic.CompareAndSwapInt32(&c.reconnecting, 0, 1) {
return
}
defer atomic.StoreInt32(&c.reconnecting, 0)
// TODO: simplify this function, use session.ring to get hosts instead of the
// connection pool
var host *HostInfo
ch := c.getConn()
if ch != nil {
host = ch.host
ch.conn.Close()
}
var newConn *Conn
if host != nil {
// try to connect to the old host
conn, err := c.session.connect(c.session.ctx, host, c)
if err != nil {
// host is dead
// TODO: this is replicated in a few places
if c.session.cfg.ConvictionPolicy.AddFailure(err, host) {
c.session.handleNodeDown(host.ConnectAddress(), host.Port())
}
} else {
newConn = conn
}
}
// TODO: should have our own round-robin for hosts so that we can try each
// in succession and guarantee that we get a different host each time.
if newConn == nil {
host := c.session.ring.rrHost()
if host == nil {
c.connect(c.session.ring.endpoints)
return
}
var err error
newConn, err = c.session.connect(c.session.ctx, host, c)
if err != nil {
// TODO: add log handler for things like this
return
}
}
if err := c.setupConn(newConn); err != nil {
newConn.Close()
Logger.Printf("gocql: control unable to register events: %v\n", err)
return
}
if refreshring {
c.session.hostSource.refreshRing()
}
}
func (c *controlConn) HandleError(conn *Conn, err error, closed bool) {
if !closed {
return
}
oldConn := c.getConn()
// If connection has long gone, and not been attempted for awhile,
// it's possible to have oldConn as nil here (#1297).
if oldConn != nil && oldConn.conn != conn {
return
}
c.reconnect(false)
}
func (c *controlConn) getConn() *connHost {
return c.conn.Load().(*connHost)
}
func (c *controlConn) writeFrame(w frameWriter) (frame, error) {
ch := c.getConn()
if ch == nil {
return nil, errNoControl
}
framer, err := ch.conn.exec(context.Background(), w, nil)
if err != nil {
return nil, err
}
return framer.parseFrame()
}
func (c *controlConn) withConnHost(fn func(*connHost) *Iter) *Iter {
const maxConnectAttempts = 5
connectAttempts := 0
for i := 0; i < maxConnectAttempts; i++ {
ch := c.getConn()
if ch == nil {
if connectAttempts > maxConnectAttempts {
break
}
connectAttempts++
c.reconnect(false)
continue
}
return fn(ch)
}
return &Iter{err: errNoControl}
}
func (c *controlConn) withConn(fn func(*Conn) *Iter) *Iter {
return c.withConnHost(func(ch *connHost) *Iter {
return fn(ch.conn)
})
}
// query will return nil if the connection is closed or nil
func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter) {
q := c.session.Query(statement, values...).Consistency(One).RoutingKey([]byte{}).Trace(nil)
for {
iter = c.withConn(func(conn *Conn) *Iter {
return conn.executeQuery(context.TODO(), q)
})
if gocqlDebug && iter.err != nil {
Logger.Printf("control: error executing %q: %v\n", statement, iter.err)
}
q.AddAttempts(1, c.getConn().host)
if iter.err == nil || !c.retry.Attempt(q) {
break
}
}
return
}
func (c *controlConn) awaitSchemaAgreement() error {
return c.withConn(func(conn *Conn) *Iter {
return &Iter{err: conn.awaitSchemaAgreement(context.TODO())}
}).err
}
func (c *controlConn) close() {
if atomic.CompareAndSwapInt32(&c.started, 1, -1) {
c.quit <- struct{}{}
}
ch := c.getConn()
if ch != nil {
ch.conn.Close()
}
}
var errNoControl = errors.New("gocql: no control connection available")

@ -0,0 +1,11 @@
// Copyright (c) 2012 The gocql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gocql
type Duration struct {
Months int32
Days int32
Nanoseconds int64
}

@ -0,0 +1,5 @@
// +build !gocql_debug
package gocql
const gocqlDebug = false

@ -0,0 +1,5 @@
// +build gocql_debug
package gocql
const gocqlDebug = true

@ -0,0 +1,9 @@
// Copyright (c) 2012-2015 The gocql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package gocql implements a fast and robust Cassandra driver for the
// Go programming language.
package gocql // import "github.com/gocql/gocql"
// TODO(tux21b): write more docs.

@ -0,0 +1,125 @@
package gocql
import "fmt"
const (
errServer = 0x0000
errProtocol = 0x000A
errCredentials = 0x0100
errUnavailable = 0x1000
errOverloaded = 0x1001
errBootstrapping = 0x1002
errTruncate = 0x1003
errWriteTimeout = 0x1100
errReadTimeout = 0x1200
errReadFailure = 0x1300
errFunctionFailure = 0x1400
errWriteFailure = 0x1500
errCDCWriteFailure = 0x1600
errSyntax = 0x2000
errUnauthorized = 0x2100
errInvalid = 0x2200
errConfig = 0x2300
errAlreadyExists = 0x2400
errUnprepared = 0x2500
)
type RequestError interface {
Code() int
Message() string
Error() string
}
type errorFrame struct {
frameHeader
code int
message string
}
func (e errorFrame) Code() int {
return e.code
}
func (e errorFrame) Message() string {
return e.message
}
func (e errorFrame) Error() string {
return e.Message()
}
func (e errorFrame) String() string {
return fmt.Sprintf("[error code=%x message=%q]", e.code, e.message)
}
type RequestErrUnavailable struct {
errorFrame
Consistency Consistency
Required int
Alive int
}
func (e *RequestErrUnavailable) String() string {
return fmt.Sprintf("[request_error_unavailable consistency=%s required=%d alive=%d]", e.Consistency, e.Required, e.Alive)
}
type ErrorMap map[string]uint16
type RequestErrWriteTimeout struct {
errorFrame
Consistency Consistency
Received int
BlockFor int
WriteType string
}
type RequestErrWriteFailure struct {
errorFrame
Consistency Consistency
Received int
BlockFor int
NumFailures int
WriteType string
ErrorMap ErrorMap
}
type RequestErrCDCWriteFailure struct {
errorFrame
}
type RequestErrReadTimeout struct {
errorFrame
Consistency Consistency
Received int
BlockFor int
DataPresent byte
}
type RequestErrAlreadyExists struct {
errorFrame
Keyspace string
Table string
}
type RequestErrUnprepared struct {
errorFrame
StatementId []byte
}
type RequestErrReadFailure struct {
errorFrame
Consistency Consistency
Received int
BlockFor int
NumFailures int
DataPresent bool
ErrorMap ErrorMap
}
type RequestErrFunctionFailure struct {
errorFrame
Keyspace string
Function string
ArgTypes []string
}

@ -0,0 +1,293 @@
package gocql
import (
"net"
"sync"
"time"
)
type eventDebouncer struct {
name string
timer *time.Timer
mu sync.Mutex
events []frame
callback func([]frame)
quit chan struct{}
}
func newEventDebouncer(name string, eventHandler func([]frame)) *eventDebouncer {
e := &eventDebouncer{
name: name,
quit: make(chan struct{}),
timer: time.NewTimer(eventDebounceTime),
callback: eventHandler,
}
e.timer.Stop()
go e.flusher()
return e
}
func (e *eventDebouncer) stop() {
e.quit <- struct{}{} // sync with flusher
close(e.quit)
}
func (e *eventDebouncer) flusher() {
for {
select {
case <-e.timer.C:
e.mu.Lock()
e.flush()
e.mu.Unlock()
case <-e.quit:
return
}
}
}
const (
eventBufferSize = 1000
eventDebounceTime = 1 * time.Second
)
// flush must be called with mu locked
func (e *eventDebouncer) flush() {
if len(e.events) == 0 {
return
}
// if the flush interval is faster than the callback then we will end up calling
// the callback multiple times, probably a bad idea. In this case we could drop
// frames?
go e.callback(e.events)
e.events = make([]frame, 0, eventBufferSize)
}
func (e *eventDebouncer) debounce(frame frame) {
e.mu.Lock()
e.timer.Reset(eventDebounceTime)
// TODO: probably need a warning to track if this threshold is too low
if len(e.events) < eventBufferSize {
e.events = append(e.events, frame)
} else {
Logger.Printf("%s: buffer full, dropping event frame: %s", e.name, frame)
}
e.mu.Unlock()
}
func (s *Session) handleEvent(framer *framer) {
frame, err := framer.parseFrame()
if err != nil {
// TODO: logger
Logger.Printf("gocql: unable to parse event frame: %v\n", err)
return
}
if gocqlDebug {
Logger.Printf("gocql: handling frame: %v\n", frame)
}
switch f := frame.(type) {
case *schemaChangeKeyspace, *schemaChangeFunction,
*schemaChangeTable, *schemaChangeAggregate, *schemaChangeType:
s.schemaEvents.debounce(frame)
case *topologyChangeEventFrame, *statusChangeEventFrame:
s.nodeEvents.debounce(frame)
default:
Logger.Printf("gocql: invalid event frame (%T): %v\n", f, f)
}
}
func (s *Session) handleSchemaEvent(frames []frame) {
// TODO: debounce events
for _, frame := range frames {
switch f := frame.(type) {
case *schemaChangeKeyspace:
s.schemaDescriber.clearSchema(f.keyspace)
s.handleKeyspaceChange(f.keyspace, f.change)
case *schemaChangeTable:
s.schemaDescriber.clearSchema(f.keyspace)
case *schemaChangeAggregate:
s.schemaDescriber.clearSchema(f.keyspace)
case *schemaChangeFunction:
s.schemaDescriber.clearSchema(f.keyspace)
case *schemaChangeType:
s.schemaDescriber.clearSchema(f.keyspace)
}
}
}
func (s *Session) handleKeyspaceChange(keyspace, change string) {
s.control.awaitSchemaAgreement()
s.policy.KeyspaceChanged(KeyspaceUpdateEvent{Keyspace: keyspace, Change: change})
}
func (s *Session) handleNodeEvent(frames []frame) {
type nodeEvent struct {
change string
host net.IP
port int
}
events := make(map[string]*nodeEvent)
for _, frame := range frames {
// TODO: can we be sure the order of events in the buffer is correct?
switch f := frame.(type) {
case *topologyChangeEventFrame:
event, ok := events[f.host.String()]
if !ok {
event = &nodeEvent{change: f.change, host: f.host, port: f.port}
events[f.host.String()] = event
}
event.change = f.change
case *statusChangeEventFrame:
event, ok := events[f.host.String()]
if !ok {
event = &nodeEvent{change: f.change, host: f.host, port: f.port}
events[f.host.String()] = event
}
event.change = f.change
}
}
for _, f := range events {
if gocqlDebug {
Logger.Printf("gocql: dispatching event: %+v\n", f)
}
switch f.change {
case "NEW_NODE":
s.handleNewNode(f.host, f.port, true)
case "REMOVED_NODE":
s.handleRemovedNode(f.host, f.port)
case "MOVED_NODE":
// java-driver handles this, not mentioned in the spec
// TODO(zariel): refresh token map
case "UP":
s.handleNodeUp(f.host, f.port, true)
case "DOWN":
s.handleNodeDown(f.host, f.port)
}
}
}
func (s *Session) addNewNode(host *HostInfo) {
if s.cfg.filterHost(host) {
return
}
host.setState(NodeUp)
s.pool.addHost(host)
s.policy.AddHost(host)
}
func (s *Session) handleNewNode(ip net.IP, port int, waitForBinary bool) {
if gocqlDebug {
Logger.Printf("gocql: Session.handleNewNode: %s:%d\n", ip.String(), port)
}
ip, port = s.cfg.translateAddressPort(ip, port)
// Get host info and apply any filters to the host
hostInfo, err := s.hostSource.getHostInfo(ip, port)
if err != nil {
Logger.Printf("gocql: events: unable to fetch host info for (%s:%d): %v\n", ip, port, err)
return
} else if hostInfo == nil {
// If hostInfo is nil, this host was filtered out by cfg.HostFilter
return
}
if t := hostInfo.Version().nodeUpDelay(); t > 0 && waitForBinary {
time.Sleep(t)
}
// should this handle token moving?
hostInfo = s.ring.addOrUpdate(hostInfo)
s.addNewNode(hostInfo)
if s.control != nil && !s.cfg.IgnorePeerAddr {
// TODO(zariel): debounce ring refresh
s.hostSource.refreshRing()
}
}
func (s *Session) handleRemovedNode(ip net.IP, port int) {
if gocqlDebug {
Logger.Printf("gocql: Session.handleRemovedNode: %s:%d\n", ip.String(), port)
}
ip, port = s.cfg.translateAddressPort(ip, port)
// we remove all nodes but only add ones which pass the filter
host := s.ring.getHost(ip)
if host == nil {
host = &HostInfo{connectAddress: ip, port: port}
}
if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
return
}
host.setState(NodeDown)
s.policy.RemoveHost(host)
s.pool.removeHost(ip)
s.ring.removeHost(ip)
if !s.cfg.IgnorePeerAddr {
s.hostSource.refreshRing()
}
}
func (s *Session) handleNodeUp(eventIp net.IP, eventPort int, waitForBinary bool) {
if gocqlDebug {
Logger.Printf("gocql: Session.handleNodeUp: %s:%d\n", eventIp.String(), eventPort)
}
ip, _ := s.cfg.translateAddressPort(eventIp, eventPort)
host := s.ring.getHost(ip)
if host == nil {
// TODO(zariel): avoid the need to translate twice in this
// case
s.handleNewNode(eventIp, eventPort, waitForBinary)
return
}
if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
return
}
if t := host.Version().nodeUpDelay(); t > 0 && waitForBinary {
time.Sleep(t)
}
s.addNewNode(host)
}
func (s *Session) handleNodeDown(ip net.IP, port int) {
if gocqlDebug {
Logger.Printf("gocql: Session.handleNodeDown: %s:%d\n", ip.String(), port)
}
host := s.ring.getHost(ip)
if host == nil {
host = &HostInfo{connectAddress: ip, port: port}
}
if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
return
}
host.setState(NodeDown)
s.policy.HostDown(host)
s.pool.hostDown(ip)
}

@ -0,0 +1,57 @@
package gocql
import "fmt"
// HostFilter interface is used when a host is discovered via server sent events.
type HostFilter interface {
// Called when a new host is discovered, returning true will cause the host
// to be added to the pools.
Accept(host *HostInfo) bool
}
// HostFilterFunc converts a func(host HostInfo) bool into a HostFilter
type HostFilterFunc func(host *HostInfo) bool
func (fn HostFilterFunc) Accept(host *HostInfo) bool {
return fn(host)
}
// AcceptAllFilter will accept all hosts
func AcceptAllFilter() HostFilter {
return HostFilterFunc(func(host *HostInfo) bool {
return true
})
}
func DenyAllFilter() HostFilter {
return HostFilterFunc(func(host *HostInfo) bool {
return false
})
}
// DataCentreHostFilter filters all hosts such that they are in the same data centre
// as the supplied data centre.
func DataCentreHostFilter(dataCentre string) HostFilter {
return HostFilterFunc(func(host *HostInfo) bool {
return host.DataCenter() == dataCentre
})
}
// WhiteListHostFilter filters incoming hosts by checking that their address is
// in the initial hosts whitelist.
func WhiteListHostFilter(hosts ...string) HostFilter {
hostInfos, err := addrsToHosts(hosts, 9042)
if err != nil {
// dont want to panic here, but rather not break the API
panic(fmt.Errorf("unable to lookup host info from address: %v", err))
}
m := make(map[string]bool, len(hostInfos))
for _, host := range hostInfos {
m[host.ConnectAddress().String()] = true
}
return HostFilterFunc(func(host *HostInfo) bool {
return m[host.ConnectAddress().String()]
})
}

File diff suppressed because it is too large Load Diff

@ -0,0 +1,33 @@
// +build gofuzz
package gocql
import "bytes"
func Fuzz(data []byte) int {
var bw bytes.Buffer
r := bytes.NewReader(data)
head, err := readHeader(r, make([]byte, 9))
if err != nil {
return 0
}
framer := newFramer(r, &bw, nil, byte(head.version))
err = framer.readFrame(&head)
if err != nil {
return 0
}
frame, err := framer.parseFrame()
if err != nil {
return 0
}
if frame != nil {
return 1
}
return 2
}

@ -0,0 +1,13 @@
module github.com/gocql/gocql
require (
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 // indirect
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 // indirect
github.com/golang/snappy v0.0.0-20170215233205-553a64147049
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed
github.com/kr/pretty v0.1.0 // indirect
github.com/stretchr/testify v1.3.0 // indirect
gopkg.in/inf.v0 v0.9.1
)
go 1.13

@ -0,0 +1,22 @@
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY=
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang/snappy v0.0.0-20170215233205-553a64147049 h1:K9KHZbXKpGydfDN0aZrsoHpLJlZsBrGMFWbgLDGnPZk=
github.com/golang/snappy v0.0.0-20170215233205-553a64147049/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8=
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=

@ -0,0 +1,432 @@
// Copyright (c) 2012 The gocql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gocql
import (
"fmt"
"math/big"
"net"
"reflect"
"strings"
"time"
"gopkg.in/inf.v0"
)
type RowData struct {
Columns []string
Values []interface{}
}
func goType(t TypeInfo) reflect.Type {
switch t.Type() {
case TypeVarchar, TypeAscii, TypeInet, TypeText:
return reflect.TypeOf(*new(string))
case TypeBigInt, TypeCounter:
return reflect.TypeOf(*new(int64))
case TypeTime:
return reflect.TypeOf(*new(time.Duration))
case TypeTimestamp:
return reflect.TypeOf(*new(time.Time))
case TypeBlob:
return reflect.TypeOf(*new([]byte))
case TypeBoolean:
return reflect.TypeOf(*new(bool))
case TypeFloat:
return reflect.TypeOf(*new(float32))
case TypeDouble:
return reflect.TypeOf(*new(float64))
case TypeInt:
return reflect.TypeOf(*new(int))
case TypeSmallInt:
return reflect.TypeOf(*new(int16))
case TypeTinyInt:
return reflect.TypeOf(*new(int8))
case TypeDecimal:
return reflect.TypeOf(*new(*inf.Dec))
case TypeUUID, TypeTimeUUID:
return reflect.TypeOf(*new(UUID))
case TypeList, TypeSet:
return reflect.SliceOf(goType(t.(CollectionType).Elem))
case TypeMap:
return reflect.MapOf(goType(t.(CollectionType).Key), goType(t.(CollectionType).Elem))
case TypeVarint:
return reflect.TypeOf(*new(*big.Int))
case TypeTuple:
// what can we do here? all there is to do is to make a list of interface{}
tuple := t.(TupleTypeInfo)
return reflect.TypeOf(make([]interface{}, len(tuple.Elems)))
case TypeUDT:
return reflect.TypeOf(make(map[string]interface{}))
case TypeDate:
return reflect.TypeOf(*new(time.Time))
case TypeDuration:
return reflect.TypeOf(*new(Duration))
default:
return nil
}
}
func dereference(i interface{}) interface{} {
return reflect.Indirect(reflect.ValueOf(i)).Interface()
}
func getCassandraBaseType(name string) Type {
switch name {
case "ascii":
return TypeAscii
case "bigint":
return TypeBigInt
case "blob":
return TypeBlob
case "boolean":
return TypeBoolean
case "counter":
return TypeCounter
case "decimal":
return TypeDecimal
case "double":
return TypeDouble
case "float":
return TypeFloat
case "int":
return TypeInt
case "tinyint":
return TypeTinyInt
case "time":
return TypeTime
case "timestamp":
return TypeTimestamp
case "uuid":
return TypeUUID
case "varchar":
return TypeVarchar
case "text":
return TypeText
case "varint":
return TypeVarint
case "timeuuid":
return TypeTimeUUID
case "inet":
return TypeInet
case "MapType":
return TypeMap
case "ListType":
return TypeList
case "SetType":
return TypeSet
case "TupleType":
return TypeTuple
default:
return TypeCustom
}
}
func getCassandraType(name string) TypeInfo {
if strings.HasPrefix(name, "frozen<") {
return getCassandraType(strings.TrimPrefix(name[:len(name)-1], "frozen<"))
} else if strings.HasPrefix(name, "set<") {
return CollectionType{
NativeType: NativeType{typ: TypeSet},
Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "set<")),
}
} else if strings.HasPrefix(name, "list<") {
return CollectionType{
NativeType: NativeType{typ: TypeList},
Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "list<")),
}
} else if strings.HasPrefix(name, "map<") {
names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "map<"))
if len(names) != 2 {
Logger.Printf("Error parsing map type, it has %d subelements, expecting 2\n", len(names))
return NativeType{
typ: TypeCustom,
}
}
return CollectionType{
NativeType: NativeType{typ: TypeMap},
Key: getCassandraType(names[0]),
Elem: getCassandraType(names[1]),
}
} else if strings.HasPrefix(name, "tuple<") {
names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "tuple<"))
types := make([]TypeInfo, len(names))
for i, name := range names {
types[i] = getCassandraType(name)
}
return TupleTypeInfo{
NativeType: NativeType{typ: TypeTuple},
Elems: types,
}
} else {
return NativeType{
typ: getCassandraBaseType(name),
}
}
}
func splitCompositeTypes(name string) []string {
if !strings.Contains(name, "<") {
return strings.Split(name, ", ")
}
var parts []string
lessCount := 0
segment := ""
for _, char := range name {
if char == ',' && lessCount == 0 {
if segment != "" {
parts = append(parts, strings.TrimSpace(segment))
}
segment = ""
continue
}
segment += string(char)
if char == '<' {
lessCount++
} else if char == '>' {
lessCount--
}
}
if segment != "" {
parts = append(parts, strings.TrimSpace(segment))
}
return parts
}
func apacheToCassandraType(t string) string {
t = strings.Replace(t, apacheCassandraTypePrefix, "", -1)
t = strings.Replace(t, "(", "<", -1)
t = strings.Replace(t, ")", ">", -1)
types := strings.FieldsFunc(t, func(r rune) bool {
return r == '<' || r == '>' || r == ','
})
for _, typ := range types {
t = strings.Replace(t, typ, getApacheCassandraType(typ).String(), -1)
}
// This is done so it exactly matches what Cassandra returns
return strings.Replace(t, ",", ", ", -1)
}
func getApacheCassandraType(class string) Type {
switch strings.TrimPrefix(class, apacheCassandraTypePrefix) {
case "AsciiType":
return TypeAscii
case "LongType":
return TypeBigInt
case "BytesType":
return TypeBlob
case "BooleanType":
return TypeBoolean
case "CounterColumnType":
return TypeCounter
case "DecimalType":
return TypeDecimal
case "DoubleType":
return TypeDouble
case "FloatType":
return TypeFloat
case "Int32Type":
return TypeInt
case "ShortType":
return TypeSmallInt
case "ByteType":
return TypeTinyInt
case "TimeType":
return TypeTime
case "DateType", "TimestampType":
return TypeTimestamp
case "UUIDType", "LexicalUUIDType":
return TypeUUID
case "UTF8Type":
return TypeVarchar
case "IntegerType":
return TypeVarint
case "TimeUUIDType":
return TypeTimeUUID
case "InetAddressType":
return TypeInet
case "MapType":
return TypeMap
case "ListType":
return TypeList
case "SetType":
return TypeSet
case "TupleType":
return TypeTuple
case "DurationType":
return TypeDuration
default:
return TypeCustom
}
}
func typeCanBeNull(typ TypeInfo) bool {
switch typ.(type) {
case CollectionType, UDTTypeInfo, TupleTypeInfo:
return false
}
return true
}
func (r *RowData) rowMap(m map[string]interface{}) {
for i, column := range r.Columns {
val := dereference(r.Values[i])
if valVal := reflect.ValueOf(val); valVal.Kind() == reflect.Slice {
valCopy := reflect.MakeSlice(valVal.Type(), valVal.Len(), valVal.Cap())
reflect.Copy(valCopy, valVal)
m[column] = valCopy.Interface()
} else {
m[column] = val
}
}
}
// TupeColumnName will return the column name of a tuple value in a column named
// c at index n. It should be used if a specific element within a tuple is needed
// to be extracted from a map returned from SliceMap or MapScan.
func TupleColumnName(c string, n int) string {
return fmt.Sprintf("%s[%d]", c, n)
}
func (iter *Iter) RowData() (RowData, error) {
if iter.err != nil {
return RowData{}, iter.err
}
columns := make([]string, 0, len(iter.Columns()))
values := make([]interface{}, 0, len(iter.Columns()))
for _, column := range iter.Columns() {
if c, ok := column.TypeInfo.(TupleTypeInfo); !ok {
val := column.TypeInfo.New()
columns = append(columns, column.Name)
values = append(values, val)
} else {
for i, elem := range c.Elems {
columns = append(columns, TupleColumnName(column.Name, i))
values = append(values, elem.New())
}
}
}
rowData := RowData{
Columns: columns,
Values: values,
}
return rowData, nil
}
// TODO(zariel): is it worth exporting this?
func (iter *Iter) rowMap() (map[string]interface{}, error) {
if iter.err != nil {
return nil, iter.err
}
rowData, _ := iter.RowData()
iter.Scan(rowData.Values...)
m := make(map[string]interface{}, len(rowData.Columns))
rowData.rowMap(m)
return m, nil
}
// SliceMap is a helper function to make the API easier to use
// returns the data from the query in the form of []map[string]interface{}
func (iter *Iter) SliceMap() ([]map[string]interface{}, error) {
if iter.err != nil {
return nil, iter.err
}
// Not checking for the error because we just did
rowData, _ := iter.RowData()
dataToReturn := make([]map[string]interface{}, 0)
for iter.Scan(rowData.Values...) {
m := make(map[string]interface{}, len(rowData.Columns))
rowData.rowMap(m)
dataToReturn = append(dataToReturn, m)
}
if iter.err != nil {
return nil, iter.err
}
return dataToReturn, nil
}
// MapScan takes a map[string]interface{} and populates it with a row
// that is returned from cassandra.
//
// Each call to MapScan() must be called with a new map object.
// During the call to MapScan() any pointers in the existing map
// are replaced with non pointer types before the call returns
//
// iter := session.Query(`SELECT * FROM mytable`).Iter()
// for {
// // New map each iteration
// row = make(map[string]interface{})
// if !iter.MapScan(row) {
// break
// }
// // Do things with row
// if fullname, ok := row["fullname"]; ok {
// fmt.Printf("Full Name: %s\n", fullname)
// }
// }
//
// You can also pass pointers in the map before each call
//
// var fullName FullName // Implements gocql.Unmarshaler and gocql.Marshaler interfaces
// var address net.IP
// var age int
// iter := session.Query(`SELECT * FROM scan_map_table`).Iter()
// for {
// // New map each iteration
// row := map[string]interface{}{
// "fullname": &fullName,
// "age": &age,
// "address": &address,
// }
// if !iter.MapScan(row) {
// break
// }
// fmt.Printf("First: %s Age: %d Address: %q\n", fullName.FirstName, age, address)
// }
func (iter *Iter) MapScan(m map[string]interface{}) bool {
if iter.err != nil {
return false
}
// Not checking for the error because we just did
rowData, _ := iter.RowData()
for i, col := range rowData.Columns {
if dest, ok := m[col]; ok {
rowData.Values[i] = dest
}
}
if iter.Scan(rowData.Values...) {
rowData.rowMap(m)
return true
}
return false
}
func copyBytes(p []byte) []byte {
b := make([]byte, len(p))
copy(b, p)
return b
}
var failDNS = false
func LookupIP(host string) ([]net.IP, error) {
if failDNS {
return nil, &net.DNSError{}
}
return net.LookupIP(host)
}

@ -0,0 +1,715 @@
package gocql
import (
"context"
"errors"
"fmt"
"net"
"strconv"
"strings"
"sync"
"time"
)
type nodeState int32
func (n nodeState) String() string {
if n == NodeUp {
return "UP"
} else if n == NodeDown {
return "DOWN"
}
return fmt.Sprintf("UNKNOWN_%d", n)
}
const (
NodeUp nodeState = iota
NodeDown
)
type cassVersion struct {
Major, Minor, Patch int
}
func (c *cassVersion) Set(v string) error {
if v == "" {
return nil
}
return c.UnmarshalCQL(nil, []byte(v))
}
func (c *cassVersion) UnmarshalCQL(info TypeInfo, data []byte) error {
return c.unmarshal(data)
}
func (c *cassVersion) unmarshal(data []byte) error {
version := strings.TrimSuffix(string(data), "-SNAPSHOT")
version = strings.TrimPrefix(version, "v")
v := strings.Split(version, ".")
if len(v) < 2 {
return fmt.Errorf("invalid version string: %s", data)
}
var err error
c.Major, err = strconv.Atoi(v[0])
if err != nil {
return fmt.Errorf("invalid major version %v: %v", v[0], err)
}
c.Minor, err = strconv.Atoi(v[1])
if err != nil {
return fmt.Errorf("invalid minor version %v: %v", v[1], err)
}
if len(v) > 2 {
c.Patch, err = strconv.Atoi(v[2])
if err != nil {
return fmt.Errorf("invalid patch version %v: %v", v[2], err)
}
}
return nil
}
func (c cassVersion) Before(major, minor, patch int) bool {
// We're comparing us (cassVersion) with the provided version (major, minor, patch)
// We return true if our version is lower (comes before) than the provided one.
if c.Major < major {
return true
} else if c.Major == major {
if c.Minor < minor {
return true
} else if c.Minor == minor && c.Patch < patch {
return true
}
}
return false
}
func (c cassVersion) AtLeast(major, minor, patch int) bool {
return !c.Before(major, minor, patch)
}
func (c cassVersion) String() string {
return fmt.Sprintf("v%d.%d.%d", c.Major, c.Minor, c.Patch)
}
func (c cassVersion) nodeUpDelay() time.Duration {
if c.Major >= 2 && c.Minor >= 2 {
// CASSANDRA-8236
return 0
}
return 10 * time.Second
}
type HostInfo struct {
// TODO(zariel): reduce locking maybe, not all values will change, but to ensure
// that we are thread safe use a mutex to access all fields.
mu sync.RWMutex
hostname string
peer net.IP
broadcastAddress net.IP
listenAddress net.IP
rpcAddress net.IP
preferredIP net.IP
connectAddress net.IP
port int
dataCenter string
rack string
hostId string
workload string
graph bool
dseVersion string
partitioner string
clusterName string
version cassVersion
state nodeState
schemaVersion string
tokens []string
}
func (h *HostInfo) Equal(host *HostInfo) bool {
if h == host {
// prevent rlock reentry
return true
}
return h.ConnectAddress().Equal(host.ConnectAddress())
}
func (h *HostInfo) Peer() net.IP {
h.mu.RLock()
defer h.mu.RUnlock()
return h.peer
}
func (h *HostInfo) setPeer(peer net.IP) *HostInfo {
h.mu.Lock()
defer h.mu.Unlock()
h.peer = peer
return h
}
func (h *HostInfo) invalidConnectAddr() bool {
h.mu.RLock()
defer h.mu.RUnlock()
addr, _ := h.connectAddressLocked()
return !validIpAddr(addr)
}
func validIpAddr(addr net.IP) bool {
return addr != nil && !addr.IsUnspecified()
}
func (h *HostInfo) connectAddressLocked() (net.IP, string) {
if validIpAddr(h.connectAddress) {
return h.connectAddress, "connect_address"
} else if validIpAddr(h.rpcAddress) {
return h.rpcAddress, "rpc_adress"
} else if validIpAddr(h.preferredIP) {
// where does perferred_ip get set?
return h.preferredIP, "preferred_ip"
} else if validIpAddr(h.broadcastAddress) {
return h.broadcastAddress, "broadcast_address"
} else if validIpAddr(h.peer) {
return h.peer, "peer"
}
return net.IPv4zero, "invalid"
}
// Returns the address that should be used to connect to the host.
// If you wish to override this, use an AddressTranslator or
// use a HostFilter to SetConnectAddress()
func (h *HostInfo) ConnectAddress() net.IP {
h.mu.RLock()
defer h.mu.RUnlock()
if addr, _ := h.connectAddressLocked(); validIpAddr(addr) {
return addr
}
panic(fmt.Sprintf("no valid connect address for host: %v. Is your cluster configured correctly?", h))
}
func (h *HostInfo) SetConnectAddress(address net.IP) *HostInfo {
// TODO(zariel): should this not be exported?
h.mu.Lock()
defer h.mu.Unlock()
h.connectAddress = address
return h
}
func (h *HostInfo) BroadcastAddress() net.IP {
h.mu.RLock()
defer h.mu.RUnlock()
return h.broadcastAddress
}
func (h *HostInfo) ListenAddress() net.IP {
h.mu.RLock()
defer h.mu.RUnlock()
return h.listenAddress
}
func (h *HostInfo) RPCAddress() net.IP {
h.mu.RLock()
defer h.mu.RUnlock()
return h.rpcAddress
}
func (h *HostInfo) PreferredIP() net.IP {
h.mu.RLock()
defer h.mu.RUnlock()
return h.preferredIP
}
func (h *HostInfo) DataCenter() string {
h.mu.RLock()
dc := h.dataCenter
h.mu.RUnlock()
return dc
}
func (h *HostInfo) setDataCenter(dataCenter string) *HostInfo {
h.mu.Lock()
defer h.mu.Unlock()
h.dataCenter = dataCenter
return h
}
func (h *HostInfo) Rack() string {
h.mu.RLock()
rack := h.rack
h.mu.RUnlock()
return rack
}
func (h *HostInfo) setRack(rack string) *HostInfo {
h.mu.Lock()
defer h.mu.Unlock()
h.rack = rack
return h
}
func (h *HostInfo) HostID() string {
h.mu.RLock()
defer h.mu.RUnlock()
return h.hostId
}
func (h *HostInfo) setHostID(hostID string) *HostInfo {
h.mu.Lock()
defer h.mu.Unlock()
h.hostId = hostID
return h
}
func (h *HostInfo) WorkLoad() string {
h.mu.RLock()
defer h.mu.RUnlock()
return h.workload
}
func (h *HostInfo) Graph() bool {
h.mu.RLock()
defer h.mu.RUnlock()
return h.graph
}
func (h *HostInfo) DSEVersion() string {
h.mu.RLock()
defer h.mu.RUnlock()
return h.dseVersion
}
func (h *HostInfo) Partitioner() string {
h.mu.RLock()
defer h.mu.RUnlock()
return h.partitioner
}
func (h *HostInfo) ClusterName() string {
h.mu.RLock()
defer h.mu.RUnlock()
return h.clusterName
}
func (h *HostInfo) Version() cassVersion {
h.mu.RLock()
defer h.mu.RUnlock()
return h.version
}
func (h *HostInfo) setVersion(major, minor, patch int) *HostInfo {
h.mu.Lock()
defer h.mu.Unlock()
h.version = cassVersion{major, minor, patch}
return h
}
func (h *HostInfo) State() nodeState {
h.mu.RLock()
defer h.mu.RUnlock()
return h.state
}
func (h *HostInfo) setState(state nodeState) *HostInfo {
h.mu.Lock()
defer h.mu.Unlock()
h.state = state
return h
}
func (h *HostInfo) Tokens() []string {
h.mu.RLock()
defer h.mu.RUnlock()
return h.tokens
}
func (h *HostInfo) setTokens(tokens []string) *HostInfo {
h.mu.Lock()
defer h.mu.Unlock()
h.tokens = tokens
return h
}
func (h *HostInfo) Port() int {
h.mu.RLock()
defer h.mu.RUnlock()
return h.port
}
func (h *HostInfo) setPort(port int) *HostInfo {
h.mu.Lock()
defer h.mu.Unlock()
h.port = port
return h
}
func (h *HostInfo) update(from *HostInfo) {
if h == from {
return
}
h.mu.Lock()
defer h.mu.Unlock()
from.mu.RLock()
defer from.mu.RUnlock()
// autogenerated do not update
if h.peer == nil {
h.peer = from.peer
}
if h.broadcastAddress == nil {
h.broadcastAddress = from.broadcastAddress
}
if h.listenAddress == nil {
h.listenAddress = from.listenAddress
}
if h.rpcAddress == nil {
h.rpcAddress = from.rpcAddress
}
if h.preferredIP == nil {
h.preferredIP = from.preferredIP
}
if h.connectAddress == nil {
h.connectAddress = from.connectAddress
}
if h.port == 0 {
h.port = from.port
}
if h.dataCenter == "" {
h.dataCenter = from.dataCenter
}
if h.rack == "" {
h.rack = from.rack
}
if h.hostId == "" {
h.hostId = from.hostId
}
if h.workload == "" {
h.workload = from.workload
}
if h.dseVersion == "" {
h.dseVersion = from.dseVersion
}
if h.partitioner == "" {
h.partitioner = from.partitioner
}
if h.clusterName == "" {
h.clusterName = from.clusterName
}
if h.version == (cassVersion{}) {
h.version = from.version
}
if h.tokens == nil {
h.tokens = from.tokens
}
}
func (h *HostInfo) IsUp() bool {
return h != nil && h.State() == NodeUp
}
func (h *HostInfo) HostnameAndPort() string {
if h.hostname == "" {
h.hostname = h.ConnectAddress().String()
}
return net.JoinHostPort(h.hostname, strconv.Itoa(h.port))
}
func (h *HostInfo) String() string {
h.mu.RLock()
defer h.mu.RUnlock()
connectAddr, source := h.connectAddressLocked()
return fmt.Sprintf("[HostInfo hostname=%q connectAddress=%q peer=%q rpc_address=%q broadcast_address=%q "+
"preferred_ip=%q connect_addr=%q connect_addr_source=%q "+
"port=%d data_centre=%q rack=%q host_id=%q version=%q state=%s num_tokens=%d]",
h.hostname, h.connectAddress, h.peer, h.rpcAddress, h.broadcastAddress, h.preferredIP,
connectAddr, source,
h.port, h.dataCenter, h.rack, h.hostId, h.version, h.state, len(h.tokens))
}
// Polls system.peers at a specific interval to find new hosts
type ringDescriber struct {
session *Session
mu sync.Mutex
prevHosts []*HostInfo
prevPartitioner string
}
// Returns true if we are using system_schema.keyspaces instead of system.schema_keyspaces
func checkSystemSchema(control *controlConn) (bool, error) {
iter := control.query("SELECT * FROM system_schema.keyspaces")
if err := iter.err; err != nil {
if errf, ok := err.(*errorFrame); ok {
if errf.code == errSyntax {
return false, nil
}
}
return false, err
}
return true, nil
}
// Given a map that represents a row from either system.local or system.peers
// return as much information as we can in *HostInfo
func (s *Session) hostInfoFromMap(row map[string]interface{}, host *HostInfo) (*HostInfo, error) {
const assertErrorMsg = "Assertion failed for %s"
var ok bool
// Default to our connected port if the cluster doesn't have port information
for key, value := range row {
switch key {
case "data_center":
host.dataCenter, ok = value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "data_center")
}
case "rack":
host.rack, ok = value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "rack")
}
case "host_id":
hostId, ok := value.(UUID)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "host_id")
}
host.hostId = hostId.String()
case "release_version":
version, ok := value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "release_version")
}
host.version.Set(version)
case "peer":
ip, ok := value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "peer")
}
host.peer = net.ParseIP(ip)
case "cluster_name":
host.clusterName, ok = value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "cluster_name")
}
case "partitioner":
host.partitioner, ok = value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "partitioner")
}
case "broadcast_address":
ip, ok := value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "broadcast_address")
}
host.broadcastAddress = net.ParseIP(ip)
case "preferred_ip":
ip, ok := value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "preferred_ip")
}
host.preferredIP = net.ParseIP(ip)
case "rpc_address":
ip, ok := value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "rpc_address")
}
host.rpcAddress = net.ParseIP(ip)
case "listen_address":
ip, ok := value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "listen_address")
}
host.listenAddress = net.ParseIP(ip)
case "workload":
host.workload, ok = value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "workload")
}
case "graph":
host.graph, ok = value.(bool)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "graph")
}
case "tokens":
host.tokens, ok = value.([]string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "tokens")
}
case "dse_version":
host.dseVersion, ok = value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "dse_version")
}
case "schema_version":
schemaVersion, ok := value.(UUID)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "schema_version")
}
host.schemaVersion = schemaVersion.String()
}
// TODO(thrawn01): Add 'port'? once CASSANDRA-7544 is complete
// Not sure what the port field will be called until the JIRA issue is complete
}
ip, port := s.cfg.translateAddressPort(host.ConnectAddress(), host.port)
host.connectAddress = ip
host.port = port
return host, nil
}
// Ask the control node for host info on all it's known peers
func (r *ringDescriber) getClusterPeerInfo() ([]*HostInfo, error) {
var hosts []*HostInfo
iter := r.session.control.withConnHost(func(ch *connHost) *Iter {
hosts = append(hosts, ch.host)
return ch.conn.query(context.TODO(), "SELECT * FROM system.peers")
})
if iter == nil {
return nil, errNoControl
}
rows, err := iter.SliceMap()
if err != nil {
// TODO(zariel): make typed error
return nil, fmt.Errorf("unable to fetch peer host info: %s", err)
}
for _, row := range rows {
// extract all available info about the peer
host, err := r.session.hostInfoFromMap(row, &HostInfo{port: r.session.cfg.Port})
if err != nil {
return nil, err
} else if !isValidPeer(host) {
// If it's not a valid peer
Logger.Printf("Found invalid peer '%s' "+
"Likely due to a gossip or snitch issue, this host will be ignored", host)
continue
}
hosts = append(hosts, host)
}
return hosts, nil
}
// Return true if the host is a valid peer
func isValidPeer(host *HostInfo) bool {
return !(len(host.RPCAddress()) == 0 ||
host.hostId == "" ||
host.dataCenter == "" ||
host.rack == "" ||
len(host.tokens) == 0)
}
// Return a list of hosts the cluster knows about
func (r *ringDescriber) GetHosts() ([]*HostInfo, string, error) {
r.mu.Lock()
defer r.mu.Unlock()
hosts, err := r.getClusterPeerInfo()
if err != nil {
return r.prevHosts, r.prevPartitioner, err
}
var partitioner string
if len(hosts) > 0 {
partitioner = hosts[0].Partitioner()
}
return hosts, partitioner, nil
}
// Given an ip/port return HostInfo for the specified ip/port
func (r *ringDescriber) getHostInfo(ip net.IP, port int) (*HostInfo, error) {
var host *HostInfo
iter := r.session.control.withConnHost(func(ch *connHost) *Iter {
if ch.host.ConnectAddress().Equal(ip) {
host = ch.host
return nil
}
return ch.conn.query(context.TODO(), "SELECT * FROM system.peers")
})
if iter != nil {
rows, err := iter.SliceMap()
if err != nil {
return nil, err
}
for _, row := range rows {
h, err := r.session.hostInfoFromMap(row, &HostInfo{port: port})
if err != nil {
return nil, err
}
if h.ConnectAddress().Equal(ip) {
host = h
break
}
}
if host == nil {
return nil, errors.New("host not found in peers table")
}
}
if host == nil {
return nil, errors.New("unable to fetch host info: invalid control connection")
} else if host.invalidConnectAddr() {
return nil, fmt.Errorf("host ConnectAddress invalid ip=%v: %v", ip, host)
}
return host, nil
}
func (r *ringDescriber) refreshRing() error {
// if we have 0 hosts this will return the previous list of hosts to
// attempt to reconnect to the cluster otherwise we would never find
// downed hosts again, could possibly have an optimisation to only
// try to add new hosts if GetHosts didnt error and the hosts didnt change.
hosts, partitioner, err := r.GetHosts()
if err != nil {
return err
}
prevHosts := r.session.ring.currentHosts()
// TODO: move this to session
for _, h := range hosts {
if filter := r.session.cfg.HostFilter; filter != nil && !filter.Accept(h) {
continue
}
if host, ok := r.session.ring.addHostIfMissing(h); !ok {
r.session.pool.addHost(h)
r.session.policy.AddHost(h)
} else {
host.update(h)
}
delete(prevHosts, h.ConnectAddress().String())
}
// TODO(zariel): it may be worth having a mutex covering the overall ring state
// in a session so that everything sees a consistent state. Becuase as is today
// events can come in and due to ordering an UP host could be removed from the cluster
for _, host := range prevHosts {
r.session.removeHost(host)
}
r.session.metadata.setPartitioner(partitioner)
r.session.policy.SetPartitioner(partitioner)
return nil
}

@ -0,0 +1,45 @@
// +build genhostinfo
package main
import (
"fmt"
"reflect"
"sync"
"github.com/gocql/gocql"
)
func gen(clause, field string) {
fmt.Printf("if h.%s == %s {\n", field, clause)
fmt.Printf("\th.%s = from.%s\n", field, field)
fmt.Println("}")
}
func main() {
t := reflect.ValueOf(&gocql.HostInfo{}).Elem().Type()
mu := reflect.TypeOf(sync.RWMutex{})
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
if f.Type == mu {
continue
}
switch f.Type.Kind() {
case reflect.Slice:
gen("nil", f.Name)
case reflect.String:
gen(`""`, f.Name)
case reflect.Int:
gen("0", f.Name)
case reflect.Struct:
gen("("+f.Type.Name()+"{})", f.Name)
case reflect.Bool, reflect.Int32:
continue
default:
panic(fmt.Sprintf("unknown field: %s", f))
}
}
}

@ -0,0 +1,16 @@
#!/usr/bin/env bash
# This is not supposed to be an error-prone script; just a convenience.
# Install CCM
pip install --user cql PyYAML six
git clone https://github.com/pcmanus/ccm.git
pushd ccm
./setup.py install --user
popd
if [ "$1" != "gocql/gocql" ]; then
USER=$(echo $1 | cut -f1 -d'/')
cd ../..
mv ${USER} gocql
fi

@ -0,0 +1,95 @@
#!/bin/bash
set -eux
function run_tests() {
local clusterSize=3
local version=$1
local auth=$2
if [ "$auth" = true ]; then
clusterSize=1
fi
local keypath="$(pwd)/testdata/pki"
local conf=(
"client_encryption_options.enabled: true"
"client_encryption_options.keystore: $keypath/.keystore"
"client_encryption_options.keystore_password: cassandra"
"client_encryption_options.require_client_auth: true"
"client_encryption_options.truststore: $keypath/.truststore"
"client_encryption_options.truststore_password: cassandra"
"concurrent_reads: 2"
"concurrent_writes: 2"
"rpc_server_type: sync"
"rpc_min_threads: 2"
"rpc_max_threads: 2"
"write_request_timeout_in_ms: 5000"
"read_request_timeout_in_ms: 5000"
)
ccm remove test || true
ccm create test -v $version -n $clusterSize -d --vnodes --jvm_arg="-Xmx256m -XX:NewSize=100m"
ccm updateconf "${conf[@]}"
if [ "$auth" = true ]
then
ccm updateconf 'authenticator: PasswordAuthenticator' 'authorizer: CassandraAuthorizer'
rm -rf $HOME/.ccm/test/node1/data/system_auth
fi
local proto=2
if [[ $version == 1.2.* ]]; then
proto=1
elif [[ $version == 2.0.* ]]; then
proto=2
elif [[ $version == 2.1.* ]]; then
proto=3
elif [[ $version == 2.2.* || $version == 3.0.* ]]; then
proto=4
ccm updateconf 'enable_user_defined_functions: true'
export JVM_EXTRA_OPTS=" -Dcassandra.test.fail_writes_ks=test -Dcassandra.custom_query_handler_class=org.apache.cassandra.cql3.CustomPayloadMirroringQueryHandler"
elif [[ $version == 3.*.* ]]; then
proto=5
ccm updateconf 'enable_user_defined_functions: true'
export JVM_EXTRA_OPTS=" -Dcassandra.test.fail_writes_ks=test -Dcassandra.custom_query_handler_class=org.apache.cassandra.cql3.CustomPayloadMirroringQueryHandler"
fi
sleep 1s
ccm list
ccm start --wait-for-binary-proto
ccm status
ccm node1 nodetool status
local args="-gocql.timeout=60s -runssl -proto=$proto -rf=3 -clusterSize=$clusterSize -autowait=2000ms -compressor=snappy -gocql.cversion=$version -cluster=$(ccm liveset) ./..."
go test -v -tags unit -race
if [ "$auth" = true ]
then
sleep 30s
go test -run=TestAuthentication -tags "integration gocql_debug" -timeout=15s -runauth $args
else
sleep 1s
go test -tags "cassandra gocql_debug" -timeout=5m -race $args
ccm clear
ccm start --wait-for-binary-proto
sleep 1s
go test -tags "integration gocql_debug" -timeout=5m -race $args
ccm clear
ccm start --wait-for-binary-proto
sleep 1s
go test -tags "ccm gocql_debug" -timeout=5m -race $args
fi
ccm remove
}
run_tests $1 $2

@ -0,0 +1,127 @@
/*
Copyright 2015 To gocql authors
Copyright 2013 Google Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package lru implements an LRU cache.
package lru
import "container/list"
// Cache is an LRU cache. It is not safe for concurrent access.
//
// This cache has been forked from github.com/golang/groupcache/lru, but
// specialized with string keys to avoid the allocations caused by wrapping them
// in interface{}.
type Cache struct {
// MaxEntries is the maximum number of cache entries before
// an item is evicted. Zero means no limit.
MaxEntries int
// OnEvicted optionally specifies a callback function to be
// executed when an entry is purged from the cache.
OnEvicted func(key string, value interface{})
ll *list.List
cache map[string]*list.Element
}
type entry struct {
key string
value interface{}
}
// New creates a new Cache.
// If maxEntries is zero, the cache has no limit and it's assumed
// that eviction is done by the caller.
func New(maxEntries int) *Cache {
return &Cache{
MaxEntries: maxEntries,
ll: list.New(),
cache: make(map[string]*list.Element),
}
}
// Add adds a value to the cache.
func (c *Cache) Add(key string, value interface{}) {
if c.cache == nil {
c.cache = make(map[string]*list.Element)
c.ll = list.New()
}
if ee, ok := c.cache[key]; ok {
c.ll.MoveToFront(ee)
ee.Value.(*entry).value = value
return
}
ele := c.ll.PushFront(&entry{key, value})
c.cache[key] = ele
if c.MaxEntries != 0 && c.ll.Len() > c.MaxEntries {
c.RemoveOldest()
}
}
// Get looks up a key's value from the cache.
func (c *Cache) Get(key string) (value interface{}, ok bool) {
if c.cache == nil {
return
}
if ele, hit := c.cache[key]; hit {
c.ll.MoveToFront(ele)
return ele.Value.(*entry).value, true
}
return
}
// Remove removes the provided key from the cache.
func (c *Cache) Remove(key string) bool {
if c.cache == nil {
return false
}
if ele, hit := c.cache[key]; hit {
c.removeElement(ele)
return true
}
return false
}
// RemoveOldest removes the oldest item from the cache.
func (c *Cache) RemoveOldest() {
if c.cache == nil {
return
}
ele := c.ll.Back()
if ele != nil {
c.removeElement(ele)
}
}
func (c *Cache) removeElement(e *list.Element) {
c.ll.Remove(e)
kv := e.Value.(*entry)
delete(c.cache, kv.key)
if c.OnEvicted != nil {
c.OnEvicted(kv.key, kv.value)
}
}
// Len returns the number of items in the cache.
func (c *Cache) Len() int {
if c.cache == nil {
return 0
}
return c.ll.Len()
}

@ -0,0 +1,135 @@
package murmur
const (
c1 int64 = -8663945395140668459 // 0x87c37b91114253d5
c2 int64 = 5545529020109919103 // 0x4cf5ad432745937f
fmix1 int64 = -49064778989728563 // 0xff51afd7ed558ccd
fmix2 int64 = -4265267296055464877 // 0xc4ceb9fe1a85ec53
)
func fmix(n int64) int64 {
// cast to unsigned for logical right bitshift (to match C* MM3 implementation)
n ^= int64(uint64(n) >> 33)
n *= fmix1
n ^= int64(uint64(n) >> 33)
n *= fmix2
n ^= int64(uint64(n) >> 33)
return n
}
func block(p byte) int64 {
return int64(int8(p))
}
func rotl(x int64, r uint8) int64 {
// cast to unsigned for logical right bitshift (to match C* MM3 implementation)
return (x << r) | (int64)((uint64(x) >> (64 - r)))
}
func Murmur3H1(data []byte) int64 {
length := len(data)
var h1, h2, k1, k2 int64
// body
nBlocks := length / 16
for i := 0; i < nBlocks; i++ {
k1, k2 = getBlock(data, i)
k1 *= c1
k1 = rotl(k1, 31)
k1 *= c2
h1 ^= k1
h1 = rotl(h1, 27)
h1 += h2
h1 = h1*5 + 0x52dce729
k2 *= c2
k2 = rotl(k2, 33)
k2 *= c1
h2 ^= k2
h2 = rotl(h2, 31)
h2 += h1
h2 = h2*5 + 0x38495ab5
}
// tail
tail := data[nBlocks*16:]
k1 = 0
k2 = 0
switch length & 15 {
case 15:
k2 ^= block(tail[14]) << 48
fallthrough
case 14:
k2 ^= block(tail[13]) << 40
fallthrough
case 13:
k2 ^= block(tail[12]) << 32
fallthrough
case 12:
k2 ^= block(tail[11]) << 24
fallthrough
case 11:
k2 ^= block(tail[10]) << 16
fallthrough
case 10:
k2 ^= block(tail[9]) << 8
fallthrough
case 9:
k2 ^= block(tail[8])
k2 *= c2
k2 = rotl(k2, 33)
k2 *= c1
h2 ^= k2
fallthrough
case 8:
k1 ^= block(tail[7]) << 56
fallthrough
case 7:
k1 ^= block(tail[6]) << 48
fallthrough
case 6:
k1 ^= block(tail[5]) << 40
fallthrough
case 5:
k1 ^= block(tail[4]) << 32
fallthrough
case 4:
k1 ^= block(tail[3]) << 24
fallthrough
case 3:
k1 ^= block(tail[2]) << 16
fallthrough
case 2:
k1 ^= block(tail[1]) << 8
fallthrough
case 1:
k1 ^= block(tail[0])
k1 *= c1
k1 = rotl(k1, 31)
k1 *= c2
h1 ^= k1
}
h1 ^= int64(length)
h2 ^= int64(length)
h1 += h2
h2 += h1
h1 = fmix(h1)
h2 = fmix(h2)
h1 += h2
// the following is extraneous since h2 is discarded
// h2 += h1
return h1
}

@ -0,0 +1,11 @@
// +build appengine
package murmur
import "encoding/binary"
func getBlock(data []byte, n int) (int64, int64) {
k1 := binary.LittleEndian.Int64(data[n*16:])
k2 := binary.LittleEndian.Int64(data[(n*16)+8:])
return k1, k2
}

@ -0,0 +1,15 @@
// +build !appengine
package murmur
import (
"unsafe"
)
func getBlock(data []byte, n int) (int64, int64) {
block := (*[2]int64)(unsafe.Pointer(&data[n*16]))
k1 := block[0]
k2 := block[1]
return k1, k2
}

@ -0,0 +1,140 @@
package streams
import (
"math"
"strconv"
"sync/atomic"
)
const bucketBits = 64
// IDGenerator tracks and allocates streams which are in use.
type IDGenerator struct {
NumStreams int
inuseStreams int32
numBuckets uint32
// streams is a bitset where each bit represents a stream, a 1 implies in use
streams []uint64
offset uint32
}
func New(protocol int) *IDGenerator {
maxStreams := 128
if protocol > 2 {
maxStreams = 32768
}
buckets := maxStreams / 64
// reserve stream 0
streams := make([]uint64, buckets)
streams[0] = 1 << 63
return &IDGenerator{
NumStreams: maxStreams,
streams: streams,
numBuckets: uint32(buckets),
offset: uint32(buckets) - 1,
}
}
func streamFromBucket(bucket, streamInBucket int) int {
return (bucket * bucketBits) + streamInBucket
}
func (s *IDGenerator) GetStream() (int, bool) {
// based closely on the java-driver stream ID generator
// avoid false sharing subsequent requests.
offset := atomic.LoadUint32(&s.offset)
for !atomic.CompareAndSwapUint32(&s.offset, offset, (offset+1)%s.numBuckets) {
offset = atomic.LoadUint32(&s.offset)
}
offset = (offset + 1) % s.numBuckets
for i := uint32(0); i < s.numBuckets; i++ {
pos := int((i + offset) % s.numBuckets)
bucket := atomic.LoadUint64(&s.streams[pos])
if bucket == math.MaxUint64 {
// all streams in use
continue
}
for j := 0; j < bucketBits; j++ {
mask := uint64(1 << streamOffset(j))
for bucket&mask == 0 {
if atomic.CompareAndSwapUint64(&s.streams[pos], bucket, bucket|mask) {
atomic.AddInt32(&s.inuseStreams, 1)
return streamFromBucket(int(pos), j), true
}
bucket = atomic.LoadUint64(&s.streams[pos])
}
}
}
return 0, false
}
func bitfmt(b uint64) string {
return strconv.FormatUint(b, 16)
}
// returns the bucket offset of a given stream
func bucketOffset(i int) int {
return i / bucketBits
}
func streamOffset(stream int) uint64 {
return bucketBits - uint64(stream%bucketBits) - 1
}
func isSet(bits uint64, stream int) bool {
return bits>>streamOffset(stream)&1 == 1
}
func (s *IDGenerator) isSet(stream int) bool {
bits := atomic.LoadUint64(&s.streams[bucketOffset(stream)])
return isSet(bits, stream)
}
func (s *IDGenerator) String() string {
size := s.numBuckets * (bucketBits + 1)
buf := make([]byte, 0, size)
for i := 0; i < int(s.numBuckets); i++ {
bits := atomic.LoadUint64(&s.streams[i])
buf = append(buf, bitfmt(bits)...)
buf = append(buf, ' ')
}
return string(buf[: size-1 : size-1])
}
func (s *IDGenerator) Clear(stream int) (inuse bool) {
offset := bucketOffset(stream)
bucket := atomic.LoadUint64(&s.streams[offset])
mask := uint64(1) << streamOffset(stream)
if bucket&mask != mask {
// already cleared
return false
}
for !atomic.CompareAndSwapUint64(&s.streams[offset], bucket, bucket & ^mask) {
bucket = atomic.LoadUint64(&s.streams[offset])
if bucket&mask != mask {
// already cleared
return false
}
}
// TODO: make this account for 0 stream being reserved
if atomic.AddInt32(&s.inuseStreams, -1) < 0 {
// TODO(zariel): remove this
panic("negative streams inuse")
}
return true
}
func (s *IDGenerator) Available() int {
return s.NumStreams - int(atomic.LoadInt32(&s.inuseStreams)) - 1
}

@ -0,0 +1,30 @@
package gocql
import (
"bytes"
"fmt"
"log"
)
type StdLogger interface {
Print(v ...interface{})
Printf(format string, v ...interface{})
Println(v ...interface{})
}
type testLogger struct {
capture bytes.Buffer
}
func (l *testLogger) Print(v ...interface{}) { fmt.Fprint(&l.capture, v...) }
func (l *testLogger) Printf(format string, v ...interface{}) { fmt.Fprintf(&l.capture, format, v...) }
func (l *testLogger) Println(v ...interface{}) { fmt.Fprintln(&l.capture, v...) }
func (l *testLogger) String() string { return l.capture.String() }
type defaultLogger struct{}
func (l *defaultLogger) Print(v ...interface{}) { log.Print(v...) }
func (l *defaultLogger) Printf(format string, v ...interface{}) { log.Printf(format, v...) }
func (l *defaultLogger) Println(v ...interface{}) { log.Println(v...) }
var Logger StdLogger = &defaultLogger{}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -0,0 +1,971 @@
// Copyright (c) 2012 The gocql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//This file will be the future home for more policies
package gocql
import (
"context"
crand "crypto/rand"
"encoding/binary"
"errors"
"fmt"
"math"
"math/rand"
"net"
"sync"
"sync/atomic"
"time"
"github.com/hailocab/go-hostpool"
)
// cowHostList implements a copy on write host list, its equivalent type is []*HostInfo
type cowHostList struct {
list atomic.Value
mu sync.Mutex
}
func (c *cowHostList) String() string {
return fmt.Sprintf("%+v", c.get())
}
func (c *cowHostList) get() []*HostInfo {
// TODO(zariel): should we replace this with []*HostInfo?
l, ok := c.list.Load().(*[]*HostInfo)
if !ok {
return nil
}
return *l
}
func (c *cowHostList) set(list []*HostInfo) {
c.mu.Lock()
c.list.Store(&list)
c.mu.Unlock()
}
// add will add a host if it not already in the list
func (c *cowHostList) add(host *HostInfo) bool {
c.mu.Lock()
l := c.get()
if n := len(l); n == 0 {
l = []*HostInfo{host}
} else {
newL := make([]*HostInfo, n+1)
for i := 0; i < n; i++ {
if host.Equal(l[i]) {
c.mu.Unlock()
return false
}
newL[i] = l[i]
}
newL[n] = host
l = newL
}
c.list.Store(&l)
c.mu.Unlock()
return true
}
func (c *cowHostList) update(host *HostInfo) {
c.mu.Lock()
l := c.get()
if len(l) == 0 {
c.mu.Unlock()
return
}
found := false
newL := make([]*HostInfo, len(l))
for i := range l {
if host.Equal(l[i]) {
newL[i] = host
found = true
} else {
newL[i] = l[i]
}
}
if found {
c.list.Store(&newL)
}
c.mu.Unlock()
}
func (c *cowHostList) remove(ip net.IP) bool {
c.mu.Lock()
l := c.get()
size := len(l)
if size == 0 {
c.mu.Unlock()
return false
}
found := false
newL := make([]*HostInfo, 0, size)
for i := 0; i < len(l); i++ {
if !l[i].ConnectAddress().Equal(ip) {
newL = append(newL, l[i])
} else {
found = true
}
}
if !found {
c.mu.Unlock()
return false
}
newL = newL[: size-1 : size-1]
c.list.Store(&newL)
c.mu.Unlock()
return true
}
// RetryableQuery is an interface that represents a query or batch statement that
// exposes the correct functions for the retry policy logic to evaluate correctly.
type RetryableQuery interface {
Attempts() int
SetConsistency(c Consistency)
GetConsistency() Consistency
Context() context.Context
}
type RetryType uint16
const (
Retry RetryType = 0x00 // retry on same connection
RetryNextHost RetryType = 0x01 // retry on another connection
Ignore RetryType = 0x02 // ignore error and return result
Rethrow RetryType = 0x03 // raise error and stop retrying
)
// ErrUnknownRetryType is returned if the retry policy returns a retry type
// unknown to the query executor.
var ErrUnknownRetryType = errors.New("unknown retry type returned by retry policy")
// RetryPolicy interface is used by gocql to determine if a query can be attempted
// again after a retryable error has been received. The interface allows gocql
// users to implement their own logic to determine if a query can be attempted
// again.
//
// See SimpleRetryPolicy as an example of implementing and using a RetryPolicy
// interface.
type RetryPolicy interface {
Attempt(RetryableQuery) bool
GetRetryType(error) RetryType
}
// SimpleRetryPolicy has simple logic for attempting a query a fixed number of times.
//
// See below for examples of usage:
//
// //Assign to the cluster
// cluster.RetryPolicy = &gocql.SimpleRetryPolicy{NumRetries: 3}
//
// //Assign to a query
// query.RetryPolicy(&gocql.SimpleRetryPolicy{NumRetries: 1})
//
type SimpleRetryPolicy struct {
NumRetries int //Number of times to retry a query
}
// Attempt tells gocql to attempt the query again based on query.Attempts being less
// than the NumRetries defined in the policy.
func (s *SimpleRetryPolicy) Attempt(q RetryableQuery) bool {
return q.Attempts() <= s.NumRetries
}
func (s *SimpleRetryPolicy) GetRetryType(err error) RetryType {
return RetryNextHost
}
// ExponentialBackoffRetryPolicy sleeps between attempts
type ExponentialBackoffRetryPolicy struct {
NumRetries int
Min, Max time.Duration
}
func (e *ExponentialBackoffRetryPolicy) Attempt(q RetryableQuery) bool {
if q.Attempts() > e.NumRetries {
return false
}
time.Sleep(e.napTime(q.Attempts()))
return true
}
// used to calculate exponentially growing time
func getExponentialTime(min time.Duration, max time.Duration, attempts int) time.Duration {
if min <= 0 {
min = 100 * time.Millisecond
}
if max <= 0 {
max = 10 * time.Second
}
minFloat := float64(min)
napDuration := minFloat * math.Pow(2, float64(attempts-1))
// add some jitter
napDuration += rand.Float64()*minFloat - (minFloat / 2)
if napDuration > float64(max) {
return time.Duration(max)
}
return time.Duration(napDuration)
}
func (e *ExponentialBackoffRetryPolicy) GetRetryType(err error) RetryType {
return RetryNextHost
}
// DowngradingConsistencyRetryPolicy: Next retry will be with the next consistency level
// provided in the slice
//
// On a read timeout: the operation is retried with the next provided consistency
// level.
//
// On a write timeout: if the operation is an :attr:`~.UNLOGGED_BATCH`
// and at least one replica acknowledged the write, the operation is
// retried with the next consistency level. Furthermore, for other
// write types, if at least one replica acknowledged the write, the
// timeout is ignored.
//
// On an unavailable exception: if at least one replica is alive, the
// operation is retried with the next provided consistency level.
type DowngradingConsistencyRetryPolicy struct {
ConsistencyLevelsToTry []Consistency
}
func (d *DowngradingConsistencyRetryPolicy) Attempt(q RetryableQuery) bool {
currentAttempt := q.Attempts()
if currentAttempt > len(d.ConsistencyLevelsToTry) {
return false
} else if currentAttempt > 0 {
q.SetConsistency(d.ConsistencyLevelsToTry[currentAttempt-1])
if gocqlDebug {
Logger.Printf("%T: set consistency to %q\n",
d,
d.ConsistencyLevelsToTry[currentAttempt-1])
}
}
return true
}
func (d *DowngradingConsistencyRetryPolicy) GetRetryType(err error) RetryType {
switch t := err.(type) {
case *RequestErrUnavailable:
if t.Alive > 0 {
return Retry
}
return Rethrow
case *RequestErrWriteTimeout:
if t.WriteType == "SIMPLE" || t.WriteType == "BATCH" || t.WriteType == "COUNTER" {
if t.Received > 0 {
return Ignore
}
return Rethrow
}
if t.WriteType == "UNLOGGED_BATCH" {
return Retry
}
return Rethrow
case *RequestErrReadTimeout:
return Retry
default:
return RetryNextHost
}
}
func (e *ExponentialBackoffRetryPolicy) napTime(attempts int) time.Duration {
return getExponentialTime(e.Min, e.Max, attempts)
}
type HostStateNotifier interface {
AddHost(host *HostInfo)
RemoveHost(host *HostInfo)
HostUp(host *HostInfo)
HostDown(host *HostInfo)
}
type KeyspaceUpdateEvent struct {
Keyspace string
Change string
}
// HostSelectionPolicy is an interface for selecting
// the most appropriate host to execute a given query.
type HostSelectionPolicy interface {
HostStateNotifier
SetPartitioner
KeyspaceChanged(KeyspaceUpdateEvent)
Init(*Session)
IsLocal(host *HostInfo) bool
//Pick returns an iteration function over selected hosts
Pick(ExecutableQuery) NextHost
}
// SelectedHost is an interface returned when picking a host from a host
// selection policy.
type SelectedHost interface {
Info() *HostInfo
Mark(error)
}
type selectedHost HostInfo
func (host *selectedHost) Info() *HostInfo {
return (*HostInfo)(host)
}
func (host *selectedHost) Mark(err error) {}
// NextHost is an iteration function over picked hosts
type NextHost func() SelectedHost
// RoundRobinHostPolicy is a round-robin load balancing policy, where each host
// is tried sequentially for each query.
func RoundRobinHostPolicy() HostSelectionPolicy {
return &roundRobinHostPolicy{}
}
type roundRobinHostPolicy struct {
hosts cowHostList
}
func (r *roundRobinHostPolicy) IsLocal(*HostInfo) bool { return true }
func (r *roundRobinHostPolicy) KeyspaceChanged(KeyspaceUpdateEvent) {}
func (r *roundRobinHostPolicy) SetPartitioner(partitioner string) {}
func (r *roundRobinHostPolicy) Init(*Session) {}
func (r *roundRobinHostPolicy) Pick(qry ExecutableQuery) NextHost {
src := r.hosts.get()
hosts := make([]*HostInfo, len(src))
copy(hosts, src)
rand := rand.New(randSource())
rand.Shuffle(len(hosts), func(i, j int) {
hosts[i], hosts[j] = hosts[j], hosts[i]
})
return roundRobbin(hosts)
}
func (r *roundRobinHostPolicy) AddHost(host *HostInfo) {
r.hosts.add(host)
}
func (r *roundRobinHostPolicy) RemoveHost(host *HostInfo) {
r.hosts.remove(host.ConnectAddress())
}
func (r *roundRobinHostPolicy) HostUp(host *HostInfo) {
r.AddHost(host)
}
func (r *roundRobinHostPolicy) HostDown(host *HostInfo) {
r.RemoveHost(host)
}
func ShuffleReplicas() func(*tokenAwareHostPolicy) {
return func(t *tokenAwareHostPolicy) {
t.shuffleReplicas = true
}
}
// NonLocalReplicasFallback enables fallback to replicas that are not considered local.
//
// TokenAwareHostPolicy used with DCAwareHostPolicy fallback first selects replicas by partition key in local DC, then
// falls back to other nodes in the local DC. Enabling NonLocalReplicasFallback causes TokenAwareHostPolicy
// to first select replicas by partition key in local DC, then replicas by partition key in remote DCs and fall back
// to other nodes in local DC.
func NonLocalReplicasFallback() func(policy *tokenAwareHostPolicy) {
return func(t *tokenAwareHostPolicy) {
t.nonLocalReplicasFallback = true
}
}
// TokenAwareHostPolicy is a token aware host selection policy, where hosts are
// selected based on the partition key, so queries are sent to the host which
// owns the partition. Fallback is used when routing information is not available.
func TokenAwareHostPolicy(fallback HostSelectionPolicy, opts ...func(*tokenAwareHostPolicy)) HostSelectionPolicy {
p := &tokenAwareHostPolicy{fallback: fallback}
for _, opt := range opts {
opt(p)
}
return p
}
// clusterMeta holds metadata about cluster topology.
// It is used inside atomic.Value and shallow copies are used when replacing it,
// so fields should not be modified in-place. Instead, to modify a field a copy of the field should be made
// and the pointer in clusterMeta updated to point to the new value.
type clusterMeta struct {
// replicas is map[keyspace]map[token]hosts
replicas map[string]tokenRingReplicas
tokenRing *tokenRing
}
type tokenAwareHostPolicy struct {
fallback HostSelectionPolicy
getKeyspaceMetadata func(keyspace string) (*KeyspaceMetadata, error)
getKeyspaceName func() string
shuffleReplicas bool
nonLocalReplicasFallback bool
// mu protects writes to hosts, partitioner, metadata.
// reads can be unlocked as long as they are not used for updating state later.
mu sync.Mutex
hosts cowHostList
partitioner string
metadata atomic.Value // *clusterMeta
}
func (t *tokenAwareHostPolicy) Init(s *Session) {
t.getKeyspaceMetadata = s.KeyspaceMetadata
t.getKeyspaceName = func() string { return s.cfg.Keyspace }
}
func (t *tokenAwareHostPolicy) IsLocal(host *HostInfo) bool {
return t.fallback.IsLocal(host)
}
func (t *tokenAwareHostPolicy) KeyspaceChanged(update KeyspaceUpdateEvent) {
t.mu.Lock()
defer t.mu.Unlock()
meta := t.getMetadataForUpdate()
t.updateReplicas(meta, update.Keyspace)
t.metadata.Store(meta)
}
// updateReplicas updates replicas in clusterMeta.
// It must be called with t.mu mutex locked.
// meta must not be nil and it's replicas field will be updated.
func (t *tokenAwareHostPolicy) updateReplicas(meta *clusterMeta, keyspace string) {
newReplicas := make(map[string]tokenRingReplicas, len(meta.replicas))
ks, err := t.getKeyspaceMetadata(keyspace)
if err == nil {
strat := getStrategy(ks)
if strat != nil {
if meta != nil && meta.tokenRing != nil {
newReplicas[keyspace] = strat.replicaMap(meta.tokenRing)
}
}
}
for ks, replicas := range meta.replicas {
if ks != keyspace {
newReplicas[ks] = replicas
}
}
meta.replicas = newReplicas
}
func (t *tokenAwareHostPolicy) SetPartitioner(partitioner string) {
t.mu.Lock()
defer t.mu.Unlock()
if t.partitioner != partitioner {
t.fallback.SetPartitioner(partitioner)
t.partitioner = partitioner
meta := t.getMetadataForUpdate()
meta.resetTokenRing(t.partitioner, t.hosts.get())
t.updateReplicas(meta, t.getKeyspaceName())
t.metadata.Store(meta)
}
}
func (t *tokenAwareHostPolicy) AddHost(host *HostInfo) {
t.mu.Lock()
if t.hosts.add(host) {
meta := t.getMetadataForUpdate()
meta.resetTokenRing(t.partitioner, t.hosts.get())
t.updateReplicas(meta, t.getKeyspaceName())
t.metadata.Store(meta)
}
t.mu.Unlock()
t.fallback.AddHost(host)
}
func (t *tokenAwareHostPolicy) AddHosts(hosts []*HostInfo) {
t.mu.Lock()
for _, host := range hosts {
t.hosts.add(host)
}
meta := t.getMetadataForUpdate()
meta.resetTokenRing(t.partitioner, t.hosts.get())
t.updateReplicas(meta, t.getKeyspaceName())
t.metadata.Store(meta)
t.mu.Unlock()
for _, host := range hosts {
t.fallback.AddHost(host)
}
}
func (t *tokenAwareHostPolicy) RemoveHost(host *HostInfo) {
t.mu.Lock()
if t.hosts.remove(host.ConnectAddress()) {
meta := t.getMetadataForUpdate()
meta.resetTokenRing(t.partitioner, t.hosts.get())
t.updateReplicas(meta, t.getKeyspaceName())
t.metadata.Store(meta)
}
t.mu.Unlock()
t.fallback.RemoveHost(host)
}
func (t *tokenAwareHostPolicy) HostUp(host *HostInfo) {
t.fallback.HostUp(host)
}
func (t *tokenAwareHostPolicy) HostDown(host *HostInfo) {
t.fallback.HostDown(host)
}
// getMetadataReadOnly returns current cluster metadata.
// Metadata uses copy on write, so the returned value should be only used for reading.
// To obtain a copy that could be updated, use getMetadataForUpdate instead.
func (t *tokenAwareHostPolicy) getMetadataReadOnly() *clusterMeta {
meta, _ := t.metadata.Load().(*clusterMeta)
return meta
}
// getMetadataForUpdate returns clusterMeta suitable for updating.
// It is a SHALLOW copy of current metadata in case it was already set or new empty clusterMeta otherwise.
// This function should be called with t.mu mutex locked and the mutex should not be released before
// storing the new metadata.
func (t *tokenAwareHostPolicy) getMetadataForUpdate() *clusterMeta {
metaReadOnly := t.getMetadataReadOnly()
meta := new(clusterMeta)
if metaReadOnly != nil {
*meta = *metaReadOnly
}
return meta
}
// resetTokenRing creates a new tokenRing.
// It must be called with t.mu locked.
func (m *clusterMeta) resetTokenRing(partitioner string, hosts []*HostInfo) {
if partitioner == "" {
// partitioner not yet set
return
}
// create a new token ring
tokenRing, err := newTokenRing(partitioner, hosts)
if err != nil {
Logger.Printf("Unable to update the token ring due to error: %s", err)
return
}
// replace the token ring
m.tokenRing = tokenRing
}
func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost {
if qry == nil {
return t.fallback.Pick(qry)
}
routingKey, err := qry.GetRoutingKey()
if err != nil {
return t.fallback.Pick(qry)
} else if routingKey == nil {
return t.fallback.Pick(qry)
}
meta := t.getMetadataReadOnly()
if meta == nil || meta.tokenRing == nil {
return t.fallback.Pick(qry)
}
token := meta.tokenRing.partitioner.Hash(routingKey)
ht := meta.replicas[qry.Keyspace()].replicasFor(token)
var replicas []*HostInfo
if ht == nil {
host, _ := meta.tokenRing.GetHostForToken(token)
replicas = []*HostInfo{host}
} else if t.shuffleReplicas {
replicas = shuffleHosts(replicas)
} else {
replicas = ht.hosts
}
var (
fallbackIter NextHost
i, j int
remote []*HostInfo
)
used := make(map[*HostInfo]bool, len(replicas))
return func() SelectedHost {
for i < len(replicas) {
h := replicas[i]
i++
if !t.fallback.IsLocal(h) {
remote = append(remote, h)
continue
}
if h.IsUp() {
used[h] = true
return (*selectedHost)(h)
}
}
if t.nonLocalReplicasFallback {
for j < len(remote) {
h := remote[j]
j++
if h.IsUp() {
used[h] = true
return (*selectedHost)(h)
}
}
}
if fallbackIter == nil {
// fallback
fallbackIter = t.fallback.Pick(qry)
}
// filter the token aware selected hosts from the fallback hosts
for fallbackHost := fallbackIter(); fallbackHost != nil; fallbackHost = fallbackIter() {
if !used[fallbackHost.Info()] {
used[fallbackHost.Info()] = true
return fallbackHost
}
}
return nil
}
}
// HostPoolHostPolicy is a host policy which uses the bitly/go-hostpool library
// to distribute queries between hosts and prevent sending queries to
// unresponsive hosts. When creating the host pool that is passed to the policy
// use an empty slice of hosts as the hostpool will be populated later by gocql.
// See below for examples of usage:
//
// // Create host selection policy using a simple host pool
// cluster.PoolConfig.HostSelectionPolicy = HostPoolHostPolicy(hostpool.New(nil))
//
// // Create host selection policy using an epsilon greedy pool
// cluster.PoolConfig.HostSelectionPolicy = HostPoolHostPolicy(
// hostpool.NewEpsilonGreedy(nil, 0, &hostpool.LinearEpsilonValueCalculator{}),
// )
//
func HostPoolHostPolicy(hp hostpool.HostPool) HostSelectionPolicy {
return &hostPoolHostPolicy{hostMap: map[string]*HostInfo{}, hp: hp}
}
type hostPoolHostPolicy struct {
hp hostpool.HostPool
mu sync.RWMutex
hostMap map[string]*HostInfo
}
func (r *hostPoolHostPolicy) Init(*Session) {}
func (r *hostPoolHostPolicy) KeyspaceChanged(KeyspaceUpdateEvent) {}
func (r *hostPoolHostPolicy) SetPartitioner(string) {}
func (r *hostPoolHostPolicy) IsLocal(*HostInfo) bool { return true }
func (r *hostPoolHostPolicy) SetHosts(hosts []*HostInfo) {
peers := make([]string, len(hosts))
hostMap := make(map[string]*HostInfo, len(hosts))
for i, host := range hosts {
ip := host.ConnectAddress().String()
peers[i] = ip
hostMap[ip] = host
}
r.mu.Lock()
r.hp.SetHosts(peers)
r.hostMap = hostMap
r.mu.Unlock()
}
func (r *hostPoolHostPolicy) AddHost(host *HostInfo) {
ip := host.ConnectAddress().String()
r.mu.Lock()
defer r.mu.Unlock()
// If the host addr is present and isn't nil return
if h, ok := r.hostMap[ip]; ok && h != nil {
return
}
// otherwise, add the host to the map
r.hostMap[ip] = host
// and construct a new peer list to give to the HostPool
hosts := make([]string, 0, len(r.hostMap))
for addr := range r.hostMap {
hosts = append(hosts, addr)
}
r.hp.SetHosts(hosts)
}
func (r *hostPoolHostPolicy) RemoveHost(host *HostInfo) {
ip := host.ConnectAddress().String()
r.mu.Lock()
defer r.mu.Unlock()
if _, ok := r.hostMap[ip]; !ok {
return
}
delete(r.hostMap, ip)
hosts := make([]string, 0, len(r.hostMap))
for _, host := range r.hostMap {
hosts = append(hosts, host.ConnectAddress().String())
}
r.hp.SetHosts(hosts)
}
func (r *hostPoolHostPolicy) HostUp(host *HostInfo) {
r.AddHost(host)
}
func (r *hostPoolHostPolicy) HostDown(host *HostInfo) {
r.RemoveHost(host)
}
func (r *hostPoolHostPolicy) Pick(qry ExecutableQuery) NextHost {
return func() SelectedHost {
r.mu.RLock()
defer r.mu.RUnlock()
if len(r.hostMap) == 0 {
return nil
}
hostR := r.hp.Get()
host, ok := r.hostMap[hostR.Host()]
if !ok {
return nil
}
return selectedHostPoolHost{
policy: r,
info: host,
hostR: hostR,
}
}
}
// selectedHostPoolHost is a host returned by the hostPoolHostPolicy and
// implements the SelectedHost interface
type selectedHostPoolHost struct {
policy *hostPoolHostPolicy
info *HostInfo
hostR hostpool.HostPoolResponse
}
func (host selectedHostPoolHost) Info() *HostInfo {
return host.info
}
func (host selectedHostPoolHost) Mark(err error) {
ip := host.info.ConnectAddress().String()
host.policy.mu.RLock()
defer host.policy.mu.RUnlock()
if _, ok := host.policy.hostMap[ip]; !ok {
// host was removed between pick and mark
return
}
host.hostR.Mark(err)
}
type dcAwareRR struct {
local string
localHosts cowHostList
remoteHosts cowHostList
}
// DCAwareRoundRobinPolicy is a host selection policies which will prioritize and
// return hosts which are in the local datacentre before returning hosts in all
// other datercentres
func DCAwareRoundRobinPolicy(localDC string) HostSelectionPolicy {
return &dcAwareRR{local: localDC}
}
func (d *dcAwareRR) Init(*Session) {}
func (d *dcAwareRR) KeyspaceChanged(KeyspaceUpdateEvent) {}
func (d *dcAwareRR) SetPartitioner(p string) {}
func (d *dcAwareRR) IsLocal(host *HostInfo) bool {
return host.DataCenter() == d.local
}
func (d *dcAwareRR) AddHost(host *HostInfo) {
if d.IsLocal(host) {
d.localHosts.add(host)
} else {
d.remoteHosts.add(host)
}
}
func (d *dcAwareRR) RemoveHost(host *HostInfo) {
if d.IsLocal(host) {
d.localHosts.remove(host.ConnectAddress())
} else {
d.remoteHosts.remove(host.ConnectAddress())
}
}
func (d *dcAwareRR) HostUp(host *HostInfo) { d.AddHost(host) }
func (d *dcAwareRR) HostDown(host *HostInfo) { d.RemoveHost(host) }
var randSeed int64
func init() {
p := make([]byte, 8)
if _, err := crand.Read(p); err != nil {
panic(err)
}
randSeed = int64(binary.BigEndian.Uint64(p))
}
func randSource() rand.Source {
return rand.NewSource(atomic.AddInt64(&randSeed, 1))
}
func roundRobbin(hosts []*HostInfo) NextHost {
var i int
return func() SelectedHost {
for i < len(hosts) {
h := hosts[i]
i++
if h.IsUp() {
return (*selectedHost)(h)
}
}
return nil
}
}
func (d *dcAwareRR) Pick(q ExecutableQuery) NextHost {
local := d.localHosts.get()
remote := d.remoteHosts.get()
hosts := make([]*HostInfo, len(local)+len(remote))
n := copy(hosts, local)
copy(hosts[n:], remote)
// TODO: use random chose-2 but that will require plumbing information
// about connection/host load to here
r := rand.New(randSource())
for _, l := range [][]*HostInfo{hosts[:len(local)], hosts[len(local):]} {
r.Shuffle(len(l), func(i, j int) {
l[i], l[j] = l[j], l[i]
})
}
return roundRobbin(hosts)
}
// ConvictionPolicy interface is used by gocql to determine if a host should be
// marked as DOWN based on the error and host info
type ConvictionPolicy interface {
// Implementations should return `true` if the host should be convicted, `false` otherwise.
AddFailure(error error, host *HostInfo) bool
//Implementations should clear out any convictions or state regarding the host.
Reset(host *HostInfo)
}
// SimpleConvictionPolicy implements a ConvictionPolicy which convicts all hosts
// regardless of error
type SimpleConvictionPolicy struct {
}
func (e *SimpleConvictionPolicy) AddFailure(error error, host *HostInfo) bool {
return true
}
func (e *SimpleConvictionPolicy) Reset(host *HostInfo) {}
// ReconnectionPolicy interface is used by gocql to determine if reconnection
// can be attempted after connection error. The interface allows gocql users
// to implement their own logic to determine how to attempt reconnection.
//
type ReconnectionPolicy interface {
GetInterval(currentRetry int) time.Duration
GetMaxRetries() int
}
// ConstantReconnectionPolicy has simple logic for returning a fixed reconnection interval.
//
// Examples of usage:
//
// cluster.ReconnectionPolicy = &gocql.ConstantReconnectionPolicy{MaxRetries: 10, Interval: 8 * time.Second}
//
type ConstantReconnectionPolicy struct {
MaxRetries int
Interval time.Duration
}
func (c *ConstantReconnectionPolicy) GetInterval(currentRetry int) time.Duration {
return c.Interval
}
func (c *ConstantReconnectionPolicy) GetMaxRetries() int {
return c.MaxRetries
}
// ExponentialReconnectionPolicy returns a growing reconnection interval.
type ExponentialReconnectionPolicy struct {
MaxRetries int
InitialInterval time.Duration
}
func (e *ExponentialReconnectionPolicy) GetInterval(currentRetry int) time.Duration {
return getExponentialTime(e.InitialInterval, math.MaxInt16*time.Second, currentRetry)
}
func (e *ExponentialReconnectionPolicy) GetMaxRetries() int {
return e.MaxRetries
}
type SpeculativeExecutionPolicy interface {
Attempts() int
Delay() time.Duration
}
type NonSpeculativeExecution struct{}
func (sp NonSpeculativeExecution) Attempts() int { return 0 } // No additional attempts
func (sp NonSpeculativeExecution) Delay() time.Duration { return 1 } // The delay. Must be positive to be used in a ticker.
type SimpleSpeculativeExecution struct {
NumAttempts int
TimeoutDelay time.Duration
}
func (sp *SimpleSpeculativeExecution) Attempts() int { return sp.NumAttempts }
func (sp *SimpleSpeculativeExecution) Delay() time.Duration { return sp.TimeoutDelay }

@ -0,0 +1,64 @@
package gocql
import (
"github.com/gocql/gocql/internal/lru"
"sync"
)
const defaultMaxPreparedStmts = 1000
// preparedLRU is the prepared statement cache
type preparedLRU struct {
mu sync.Mutex
lru *lru.Cache
}
// Max adjusts the maximum size of the cache and cleans up the oldest records if
// the new max is lower than the previous value. Not concurrency safe.
func (p *preparedLRU) max(max int) {
p.mu.Lock()
defer p.mu.Unlock()
for p.lru.Len() > max {
p.lru.RemoveOldest()
}
p.lru.MaxEntries = max
}
func (p *preparedLRU) clear() {
p.mu.Lock()
defer p.mu.Unlock()
for p.lru.Len() > 0 {
p.lru.RemoveOldest()
}
}
func (p *preparedLRU) add(key string, val *inflightPrepare) {
p.mu.Lock()
defer p.mu.Unlock()
p.lru.Add(key, val)
}
func (p *preparedLRU) remove(key string) bool {
p.mu.Lock()
defer p.mu.Unlock()
return p.lru.Remove(key)
}
func (p *preparedLRU) execIfMissing(key string, fn func(lru *lru.Cache) *inflightPrepare) (*inflightPrepare, bool) {
p.mu.Lock()
defer p.mu.Unlock()
val, ok := p.lru.Get(key)
if ok {
return val.(*inflightPrepare), true
}
return fn(p.lru), false
}
func (p *preparedLRU) keyFor(addr, keyspace, statement string) string {
// TODO: maybe use []byte for keys?
return addr + keyspace + statement
}

@ -0,0 +1,161 @@
package gocql
import (
"context"
"time"
)
type ExecutableQuery interface {
execute(ctx context.Context, conn *Conn) *Iter
attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo)
retryPolicy() RetryPolicy
speculativeExecutionPolicy() SpeculativeExecutionPolicy
GetRoutingKey() ([]byte, error)
Keyspace() string
IsIdempotent() bool
withContext(context.Context) ExecutableQuery
RetryableQuery
}
type queryExecutor struct {
pool *policyConnPool
policy HostSelectionPolicy
}
func (q *queryExecutor) attemptQuery(ctx context.Context, qry ExecutableQuery, conn *Conn) *Iter {
start := time.Now()
iter := qry.execute(ctx, conn)
end := time.Now()
qry.attempt(q.pool.keyspace, end, start, iter, conn.host)
return iter
}
func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp SpeculativeExecutionPolicy, results chan *Iter) *Iter {
ticker := time.NewTicker(sp.Delay())
defer ticker.Stop()
for i := 0; i < sp.Attempts(); i++ {
select {
case <-ticker.C:
go q.run(ctx, qry, results)
case <-ctx.Done():
return &Iter{err: ctx.Err()}
case iter := <-results:
return iter
}
}
return nil
}
func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
// check if the query is not marked as idempotent, if
// it is, we force the policy to NonSpeculative
sp := qry.speculativeExecutionPolicy()
if !qry.IsIdempotent() || sp.Attempts() == 0 {
return q.do(qry.Context(), qry), nil
}
ctx, cancel := context.WithCancel(qry.Context())
defer cancel()
results := make(chan *Iter, 1)
// Launch the main execution
go q.run(ctx, qry, results)
// The speculative executions are launched _in addition_ to the main
// execution, on a timer. So Speculation{2} would make 3 executions running
// in total.
if iter := q.speculate(ctx, qry, sp, results); iter != nil {
return iter, nil
}
select {
case iter := <-results:
return iter, nil
case <-ctx.Done():
return &Iter{err: ctx.Err()}, nil
}
}
func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery) *Iter {
hostIter := q.policy.Pick(qry)
selectedHost := hostIter()
rt := qry.retryPolicy()
var lastErr error
var iter *Iter
for selectedHost != nil {
host := selectedHost.Info()
if host == nil || !host.IsUp() {
selectedHost = hostIter()
continue
}
pool, ok := q.pool.getPool(host)
if !ok {
selectedHost = hostIter()
continue
}
conn := pool.Pick()
if conn == nil {
selectedHost = hostIter()
continue
}
iter = q.attemptQuery(ctx, qry, conn)
iter.host = selectedHost.Info()
// Update host
switch iter.err {
case context.Canceled, context.DeadlineExceeded, ErrNotFound:
// those errors represents logical errors, they should not count
// toward removing a node from the pool
selectedHost.Mark(nil)
return iter
default:
selectedHost.Mark(iter.err)
}
// Exit if the query was successful
// or no retry policy defined or retry attempts were reached
if iter.err == nil || rt == nil || !rt.Attempt(qry) {
return iter
}
lastErr = iter.err
// If query is unsuccessful, check the error with RetryPolicy to retry
switch rt.GetRetryType(iter.err) {
case Retry:
// retry on the same host
continue
case Rethrow, Ignore:
return iter
case RetryNextHost:
// retry on the next host
selectedHost = hostIter()
continue
default:
// Undefined? Return nil and error, this will panic in the requester
return &Iter{err: ErrUnknownRetryType}
}
}
if lastErr != nil {
return &Iter{err: lastErr}
}
return &Iter{err: ErrNoConnections}
}
func (q *queryExecutor) run(ctx context.Context, qry ExecutableQuery, results chan<- *Iter) {
select {
case results <- q.do(ctx, qry):
case <-ctx.Done():
}
}

@ -0,0 +1,152 @@
package gocql
import (
"fmt"
"net"
"sync"
"sync/atomic"
)
type ring struct {
// endpoints are the set of endpoints which the driver will attempt to connect
// to in the case it can not reach any of its hosts. They are also used to boot
// strap the initial connection.
endpoints []*HostInfo
// hosts are the set of all hosts in the cassandra ring that we know of
mu sync.RWMutex
hosts map[string]*HostInfo
hostList []*HostInfo
pos uint32
// TODO: we should store the ring metadata here also.
}
func (r *ring) rrHost() *HostInfo {
// TODO: should we filter hosts that get used here? These hosts will be used
// for the control connection, should we also provide an iterator?
r.mu.RLock()
defer r.mu.RUnlock()
if len(r.hostList) == 0 {
return nil
}
pos := int(atomic.AddUint32(&r.pos, 1) - 1)
return r.hostList[pos%len(r.hostList)]
}
func (r *ring) getHost(ip net.IP) *HostInfo {
r.mu.RLock()
host := r.hosts[ip.String()]
r.mu.RUnlock()
return host
}
func (r *ring) allHosts() []*HostInfo {
r.mu.RLock()
hosts := make([]*HostInfo, 0, len(r.hosts))
for _, host := range r.hosts {
hosts = append(hosts, host)
}
r.mu.RUnlock()
return hosts
}
func (r *ring) currentHosts() map[string]*HostInfo {
r.mu.RLock()
hosts := make(map[string]*HostInfo, len(r.hosts))
for k, v := range r.hosts {
hosts[k] = v
}
r.mu.RUnlock()
return hosts
}
func (r *ring) addHost(host *HostInfo) bool {
// TODO(zariel): key all host info by HostID instead of
// ip addresses
if host.invalidConnectAddr() {
panic(fmt.Sprintf("invalid host: %v", host))
}
ip := host.ConnectAddress().String()
r.mu.Lock()
if r.hosts == nil {
r.hosts = make(map[string]*HostInfo)
}
_, ok := r.hosts[ip]
if !ok {
r.hostList = append(r.hostList, host)
}
r.hosts[ip] = host
r.mu.Unlock()
return ok
}
func (r *ring) addOrUpdate(host *HostInfo) *HostInfo {
if existingHost, ok := r.addHostIfMissing(host); ok {
existingHost.update(host)
host = existingHost
}
return host
}
func (r *ring) addHostIfMissing(host *HostInfo) (*HostInfo, bool) {
if host.invalidConnectAddr() {
panic(fmt.Sprintf("invalid host: %v", host))
}
ip := host.ConnectAddress().String()
r.mu.Lock()
if r.hosts == nil {
r.hosts = make(map[string]*HostInfo)
}
existing, ok := r.hosts[ip]
if !ok {
r.hosts[ip] = host
existing = host
r.hostList = append(r.hostList, host)
}
r.mu.Unlock()
return existing, ok
}
func (r *ring) removeHost(ip net.IP) bool {
r.mu.Lock()
if r.hosts == nil {
r.hosts = make(map[string]*HostInfo)
}
k := ip.String()
_, ok := r.hosts[k]
if ok {
for i, host := range r.hostList {
if host.ConnectAddress().Equal(ip) {
r.hostList = append(r.hostList[:i], r.hostList[i+1:]...)
break
}
}
}
delete(r.hosts, k)
r.mu.Unlock()
return ok
}
type clusterMetadata struct {
mu sync.RWMutex
partitioner string
}
func (c *clusterMetadata) setPartitioner(partitioner string) {
c.mu.Lock()
defer c.mu.Unlock()
if c.partitioner != partitioner {
// TODO: update other things now
c.partitioner = partitioner
}
}

File diff suppressed because it is too large Load Diff

@ -0,0 +1,223 @@
// Copyright (c) 2015 The gocql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gocql
import (
"bytes"
"crypto/md5"
"fmt"
"math/big"
"sort"
"strconv"
"strings"
"github.com/gocql/gocql/internal/murmur"
)
// a token partitioner
type partitioner interface {
Name() string
Hash([]byte) token
ParseString(string) token
}
// a token
type token interface {
fmt.Stringer
Less(token) bool
}
// murmur3 partitioner and token
type murmur3Partitioner struct{}
type murmur3Token int64
func (p murmur3Partitioner) Name() string {
return "Murmur3Partitioner"
}
func (p murmur3Partitioner) Hash(partitionKey []byte) token {
h1 := murmur.Murmur3H1(partitionKey)
return murmur3Token(h1)
}
// murmur3 little-endian, 128-bit hash, but returns only h1
func (p murmur3Partitioner) ParseString(str string) token {
val, _ := strconv.ParseInt(str, 10, 64)
return murmur3Token(val)
}
func (m murmur3Token) String() string {
return strconv.FormatInt(int64(m), 10)
}
func (m murmur3Token) Less(token token) bool {
return m < token.(murmur3Token)
}
// order preserving partitioner and token
type orderedPartitioner struct{}
type orderedToken string
func (p orderedPartitioner) Name() string {
return "OrderedPartitioner"
}
func (p orderedPartitioner) Hash(partitionKey []byte) token {
// the partition key is the token
return orderedToken(partitionKey)
}
func (p orderedPartitioner) ParseString(str string) token {
return orderedToken(str)
}
func (o orderedToken) String() string {
return string(o)
}
func (o orderedToken) Less(token token) bool {
return o < token.(orderedToken)
}
// random partitioner and token
type randomPartitioner struct{}
type randomToken big.Int
func (r randomPartitioner) Name() string {
return "RandomPartitioner"
}
// 2 ** 128
var maxHashInt, _ = new(big.Int).SetString("340282366920938463463374607431768211456", 10)
func (p randomPartitioner) Hash(partitionKey []byte) token {
sum := md5.Sum(partitionKey)
val := new(big.Int)
val.SetBytes(sum[:])
if sum[0] > 127 {
val.Sub(val, maxHashInt)
val.Abs(val)
}
return (*randomToken)(val)
}
func (p randomPartitioner) ParseString(str string) token {
val := new(big.Int)
val.SetString(str, 10)
return (*randomToken)(val)
}
func (r *randomToken) String() string {
return (*big.Int)(r).String()
}
func (r *randomToken) Less(token token) bool {
return -1 == (*big.Int)(r).Cmp((*big.Int)(token.(*randomToken)))
}
type hostToken struct {
token token
host *HostInfo
}
func (ht hostToken) String() string {
return fmt.Sprintf("{token=%v host=%v}", ht.token, ht.host.HostID())
}
// a data structure for organizing the relationship between tokens and hosts
type tokenRing struct {
partitioner partitioner
tokens []hostToken
hosts []*HostInfo
}
func newTokenRing(partitioner string, hosts []*HostInfo) (*tokenRing, error) {
tokenRing := &tokenRing{
hosts: hosts,
}
if strings.HasSuffix(partitioner, "Murmur3Partitioner") {
tokenRing.partitioner = murmur3Partitioner{}
} else if strings.HasSuffix(partitioner, "OrderedPartitioner") {
tokenRing.partitioner = orderedPartitioner{}
} else if strings.HasSuffix(partitioner, "RandomPartitioner") {
tokenRing.partitioner = randomPartitioner{}
} else {
return nil, fmt.Errorf("Unsupported partitioner '%s'", partitioner)
}
for _, host := range hosts {
for _, strToken := range host.Tokens() {
token := tokenRing.partitioner.ParseString(strToken)
tokenRing.tokens = append(tokenRing.tokens, hostToken{token, host})
}
}
sort.Sort(tokenRing)
return tokenRing, nil
}
func (t *tokenRing) Len() int {
return len(t.tokens)
}
func (t *tokenRing) Less(i, j int) bool {
return t.tokens[i].token.Less(t.tokens[j].token)
}
func (t *tokenRing) Swap(i, j int) {
t.tokens[i], t.tokens[j] = t.tokens[j], t.tokens[i]
}
func (t *tokenRing) String() string {
buf := &bytes.Buffer{}
buf.WriteString("TokenRing(")
if t.partitioner != nil {
buf.WriteString(t.partitioner.Name())
}
buf.WriteString("){")
sep := ""
for i, th := range t.tokens {
buf.WriteString(sep)
sep = ","
buf.WriteString("\n\t[")
buf.WriteString(strconv.Itoa(i))
buf.WriteString("]")
buf.WriteString(th.token.String())
buf.WriteString(":")
buf.WriteString(th.host.ConnectAddress().String())
}
buf.WriteString("\n}")
return string(buf.Bytes())
}
func (t *tokenRing) GetHostForPartitionKey(partitionKey []byte) (host *HostInfo, endToken token) {
if t == nil {
return nil, nil
}
return t.GetHostForToken(t.partitioner.Hash(partitionKey))
}
func (t *tokenRing) GetHostForToken(token token) (host *HostInfo, endToken token) {
if t == nil || len(t.tokens) == 0 {
return nil, nil
}
// find the primary replica
p := sort.Search(len(t.tokens), func(i int) bool {
return !t.tokens[i].token.Less(token)
})
if p == len(t.tokens) {
// wrap around to the first in the ring
p = 0
}
v := t.tokens[p]
return v.host, v.token
}

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save