mirror of https://github.com/sunface/rust-course
parent
f9a0b0696d
commit
594f806d59
@ -0,0 +1,16 @@
|
|||||||
|
module github.com/go-rust/im.dev
|
||||||
|
|
||||||
|
go 1.13
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/dgrijalva/jwt-go v3.2.0+incompatible // indirect
|
||||||
|
github.com/gocql/gocql v0.0.0-20200103014340-68f928edb90a
|
||||||
|
github.com/labstack/echo v3.3.10+incompatible
|
||||||
|
github.com/labstack/gommon v0.3.0 // indirect
|
||||||
|
github.com/microcosm-cc/bluemonday v1.0.2
|
||||||
|
github.com/sony/sonyflake v1.0.0
|
||||||
|
github.com/spf13/cobra v0.0.5
|
||||||
|
github.com/valyala/fasttemplate v1.1.0 // indirect
|
||||||
|
go.uber.org/zap v1.13.0
|
||||||
|
gopkg.in/yaml.v2 v2.2.7
|
||||||
|
)
|
@ -0,0 +1,124 @@
|
|||||||
|
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
|
||||||
|
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||||
|
github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8=
|
||||||
|
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY=
|
||||||
|
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k=
|
||||||
|
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY=
|
||||||
|
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4=
|
||||||
|
github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE=
|
||||||
|
github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk=
|
||||||
|
github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
|
||||||
|
github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE=
|
||||||
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/deckarep/golang-set v1.7.1 h1:SCQV0S6gTtp6itiFrTqI+pfmJ4LN85S1YzhDf9rTHJQ=
|
||||||
|
github.com/deckarep/golang-set v1.7.1/go.mod h1:93vsz/8Wt4joVM7c2AVqh+YRMiUSc14yDtF28KmMOgQ=
|
||||||
|
github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM=
|
||||||
|
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
|
||||||
|
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
|
||||||
|
github.com/gocql/gocql v0.0.0-20200103014340-68f928edb90a h1:f/7VP2EmdQagG92I/75YnM3ZIeCYa61BT2kZoJFptHM=
|
||||||
|
github.com/gocql/gocql v0.0.0-20200103014340-68f928edb90a/go.mod h1:DL0ekTmBSTdlNF25Orwt/JMzqIq3EJ4MVa/J/uK64OY=
|
||||||
|
github.com/golang/snappy v0.0.0-20170215233205-553a64147049 h1:K9KHZbXKpGydfDN0aZrsoHpLJlZsBrGMFWbgLDGnPZk=
|
||||||
|
github.com/golang/snappy v0.0.0-20170215233205-553a64147049/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||||
|
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
|
||||||
|
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8=
|
||||||
|
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4=
|
||||||
|
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
|
||||||
|
github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM=
|
||||||
|
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
|
||||||
|
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||||
|
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
||||||
|
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||||
|
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||||
|
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||||
|
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||||
|
github.com/labstack/echo v3.3.10+incompatible h1:pGRcYk231ExFAyoAjAfD85kQzRJCRI8bbnE7CX5OEgg=
|
||||||
|
github.com/labstack/echo v3.3.10+incompatible/go.mod h1:0INS7j/VjnFxD4E2wkz67b8cVwCLbBmJyDaka6Cmk1s=
|
||||||
|
github.com/labstack/gommon v0.3.0 h1:JEeO0bvc78PKdyHxloTKiF8BD5iGrH8T6MSeGvSgob0=
|
||||||
|
github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k=
|
||||||
|
github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ=
|
||||||
|
github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx8mU=
|
||||||
|
github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
|
||||||
|
github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
|
||||||
|
github.com/mattn/go-isatty v0.0.9 h1:d5US/mDsogSGW37IV293h//ZFaeajb69h+EHFsv2xGg=
|
||||||
|
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
|
||||||
|
github.com/microcosm-cc/bluemonday v1.0.2 h1:5lPfLTTAvAbtS0VqT+94yOtFnGfUWYyx0+iToC3Os3s=
|
||||||
|
github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc=
|
||||||
|
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
|
||||||
|
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
|
||||||
|
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
|
||||||
|
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
|
||||||
|
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||||
|
github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
|
||||||
|
github.com/sony/sonyflake v1.0.0 h1:MpU6Ro7tfXwgn2l5eluf9xQvQJDROTBImNCfRXn/YeM=
|
||||||
|
github.com/sony/sonyflake v1.0.0/go.mod h1:Jv3cfhf/UFtolOTTRd3q4Nl6ENqM+KfyZ5PseKfZGF4=
|
||||||
|
github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ=
|
||||||
|
github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
|
||||||
|
github.com/spf13/cobra v0.0.5 h1:f0B+LkLX6DtmRH1isoNA9VTtNUK9K8xYd28JNNfOv/s=
|
||||||
|
github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU=
|
||||||
|
github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo=
|
||||||
|
github.com/spf13/pflag v1.0.3 h1:zPAT6CGy6wXeQ7NtTnaTerfKOsV6V6F8agHXFiazDkg=
|
||||||
|
github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
|
||||||
|
github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s=
|
||||||
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||||
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
|
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
|
||||||
|
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||||
|
github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0=
|
||||||
|
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||||
|
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||||
|
github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8=
|
||||||
|
github.com/valyala/fasttemplate v1.1.0 h1:RZqt0yGBsps8NGvLSGW804QQqCUYYLsaOjTVHy1Ocw4=
|
||||||
|
github.com/valyala/fasttemplate v1.1.0/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8=
|
||||||
|
github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q=
|
||||||
|
go.uber.org/atomic v1.5.0 h1:OI5t8sDa1Or+q8AeE+yKeB/SDYioSHAgcVljj9JIETY=
|
||||||
|
go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
|
||||||
|
go.uber.org/multierr v1.3.0 h1:sFPn2GLc3poCkfrpIXGhBD2X0CMIo4Q/zSULXrj/+uc=
|
||||||
|
go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4=
|
||||||
|
go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee h1:0mgffUl7nfd+FpvXMVz4IDEaUSmT1ysygQC7qYo7sG4=
|
||||||
|
go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA=
|
||||||
|
go.uber.org/zap v1.13.0 h1:nR6NoDBgAf67s68NhaXbsojM+2gxp3S1hWkHDl27pVU=
|
||||||
|
go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM=
|
||||||
|
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||||
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
|
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529 h1:iMGN4xG0cnqj3t+zOM8wUB0BiPKHEwSxEZCvzcbZuvk=
|
||||||
|
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||||
|
golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs=
|
||||||
|
golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
|
||||||
|
golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc=
|
||||||
|
golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
|
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||||
|
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||||
|
golang.org/x/net v0.0.0-20190620200207-3b0461eec859 h1:R/3boaszxrf1GEUWTVDzSKVwLmSJpwZ1yqXm8j0v2QI=
|
||||||
|
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||||
|
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
|
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
|
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a h1:aYOabOQFp6Vj6W1F80affTUvO9UxmJRx8K0gsfABByQ=
|
||||||
|
golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
|
||||||
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
|
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
|
||||||
|
golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
|
||||||
|
golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||||
|
golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5 h1:hKsoRgsbwY1NafxrwTs+k64bikrLBkAgPir1TNCj3Zs=
|
||||||
|
golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||||
|
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
||||||
|
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
|
||||||
|
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
|
||||||
|
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
|
gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo=
|
||||||
|
gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
|
honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM=
|
||||||
|
honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
|
@ -1,12 +0,0 @@
|
|||||||
import { action, observable } from 'mobx'
|
|
||||||
|
|
||||||
class User{
|
|
||||||
@observable info = new Map()
|
|
||||||
|
|
||||||
@action
|
|
||||||
setInfo = (info) => {
|
|
||||||
this.info.replace(info)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
export default new User()
|
|
@ -0,0 +1,18 @@
|
|||||||
|
import { action, observable } from 'mobx'
|
||||||
|
|
||||||
|
type UserInfo = {
|
||||||
|
address: string,
|
||||||
|
email: string,
|
||||||
|
tel: string,
|
||||||
|
avatar: string
|
||||||
|
}
|
||||||
|
class User{
|
||||||
|
@observable info:UserInfo = {address:'',email:'',tel:'',avatar:''}
|
||||||
|
|
||||||
|
@action
|
||||||
|
setInfo = (info:UserInfo) => {
|
||||||
|
this.info = info
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export default new User()
|
@ -0,0 +1,5 @@
|
|||||||
|
TAGS
|
||||||
|
tags
|
||||||
|
.*.swp
|
||||||
|
tomlcheck/tomlcheck
|
||||||
|
toml.test
|
@ -0,0 +1,15 @@
|
|||||||
|
language: go
|
||||||
|
go:
|
||||||
|
- 1.1
|
||||||
|
- 1.2
|
||||||
|
- 1.3
|
||||||
|
- 1.4
|
||||||
|
- 1.5
|
||||||
|
- 1.6
|
||||||
|
- tip
|
||||||
|
install:
|
||||||
|
- go install ./...
|
||||||
|
- go get github.com/BurntSushi/toml-test
|
||||||
|
script:
|
||||||
|
- export PATH="$PATH:$HOME/gopath/bin"
|
||||||
|
- make test
|
@ -0,0 +1,3 @@
|
|||||||
|
Compatible with TOML version
|
||||||
|
[v0.4.0](https://github.com/toml-lang/toml/blob/v0.4.0/versions/en/toml-v0.4.0.md)
|
||||||
|
|
@ -0,0 +1,21 @@
|
|||||||
|
The MIT License (MIT)
|
||||||
|
|
||||||
|
Copyright (c) 2013 TOML authors
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in
|
||||||
|
all copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||||
|
THE SOFTWARE.
|
@ -0,0 +1,19 @@
|
|||||||
|
install:
|
||||||
|
go install ./...
|
||||||
|
|
||||||
|
test: install
|
||||||
|
go test -v
|
||||||
|
toml-test toml-test-decoder
|
||||||
|
toml-test -encoder toml-test-encoder
|
||||||
|
|
||||||
|
fmt:
|
||||||
|
gofmt -w *.go */*.go
|
||||||
|
colcheck *.go */*.go
|
||||||
|
|
||||||
|
tags:
|
||||||
|
find ./ -name '*.go' -print0 | xargs -0 gotags > TAGS
|
||||||
|
|
||||||
|
push:
|
||||||
|
git push origin master
|
||||||
|
git push github master
|
||||||
|
|
@ -0,0 +1,218 @@
|
|||||||
|
## TOML parser and encoder for Go with reflection
|
||||||
|
|
||||||
|
TOML stands for Tom's Obvious, Minimal Language. This Go package provides a
|
||||||
|
reflection interface similar to Go's standard library `json` and `xml`
|
||||||
|
packages. This package also supports the `encoding.TextUnmarshaler` and
|
||||||
|
`encoding.TextMarshaler` interfaces so that you can define custom data
|
||||||
|
representations. (There is an example of this below.)
|
||||||
|
|
||||||
|
Spec: https://github.com/toml-lang/toml
|
||||||
|
|
||||||
|
Compatible with TOML version
|
||||||
|
[v0.4.0](https://github.com/toml-lang/toml/blob/master/versions/en/toml-v0.4.0.md)
|
||||||
|
|
||||||
|
Documentation: https://godoc.org/github.com/BurntSushi/toml
|
||||||
|
|
||||||
|
Installation:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go get github.com/BurntSushi/toml
|
||||||
|
```
|
||||||
|
|
||||||
|
Try the toml validator:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go get github.com/BurntSushi/toml/cmd/tomlv
|
||||||
|
tomlv some-toml-file.toml
|
||||||
|
```
|
||||||
|
|
||||||
|
[](https://travis-ci.org/BurntSushi/toml) [](https://godoc.org/github.com/BurntSushi/toml)
|
||||||
|
|
||||||
|
### Testing
|
||||||
|
|
||||||
|
This package passes all tests in
|
||||||
|
[toml-test](https://github.com/BurntSushi/toml-test) for both the decoder
|
||||||
|
and the encoder.
|
||||||
|
|
||||||
|
### Examples
|
||||||
|
|
||||||
|
This package works similarly to how the Go standard library handles `XML`
|
||||||
|
and `JSON`. Namely, data is loaded into Go values via reflection.
|
||||||
|
|
||||||
|
For the simplest example, consider some TOML file as just a list of keys
|
||||||
|
and values:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
Age = 25
|
||||||
|
Cats = [ "Cauchy", "Plato" ]
|
||||||
|
Pi = 3.14
|
||||||
|
Perfection = [ 6, 28, 496, 8128 ]
|
||||||
|
DOB = 1987-07-05T05:45:00Z
|
||||||
|
```
|
||||||
|
|
||||||
|
Which could be defined in Go as:
|
||||||
|
|
||||||
|
```go
|
||||||
|
type Config struct {
|
||||||
|
Age int
|
||||||
|
Cats []string
|
||||||
|
Pi float64
|
||||||
|
Perfection []int
|
||||||
|
DOB time.Time // requires `import time`
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
And then decoded with:
|
||||||
|
|
||||||
|
```go
|
||||||
|
var conf Config
|
||||||
|
if _, err := toml.Decode(tomlData, &conf); err != nil {
|
||||||
|
// handle error
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
You can also use struct tags if your struct field name doesn't map to a TOML
|
||||||
|
key value directly:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
some_key_NAME = "wat"
|
||||||
|
```
|
||||||
|
|
||||||
|
```go
|
||||||
|
type TOML struct {
|
||||||
|
ObscureKey string `toml:"some_key_NAME"`
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Using the `encoding.TextUnmarshaler` interface
|
||||||
|
|
||||||
|
Here's an example that automatically parses duration strings into
|
||||||
|
`time.Duration` values:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[[song]]
|
||||||
|
name = "Thunder Road"
|
||||||
|
duration = "4m49s"
|
||||||
|
|
||||||
|
[[song]]
|
||||||
|
name = "Stairway to Heaven"
|
||||||
|
duration = "8m03s"
|
||||||
|
```
|
||||||
|
|
||||||
|
Which can be decoded with:
|
||||||
|
|
||||||
|
```go
|
||||||
|
type song struct {
|
||||||
|
Name string
|
||||||
|
Duration duration
|
||||||
|
}
|
||||||
|
type songs struct {
|
||||||
|
Song []song
|
||||||
|
}
|
||||||
|
var favorites songs
|
||||||
|
if _, err := toml.Decode(blob, &favorites); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, s := range favorites.Song {
|
||||||
|
fmt.Printf("%s (%s)\n", s.Name, s.Duration)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
And you'll also need a `duration` type that satisfies the
|
||||||
|
`encoding.TextUnmarshaler` interface:
|
||||||
|
|
||||||
|
```go
|
||||||
|
type duration struct {
|
||||||
|
time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *duration) UnmarshalText(text []byte) error {
|
||||||
|
var err error
|
||||||
|
d.Duration, err = time.ParseDuration(string(text))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### More complex usage
|
||||||
|
|
||||||
|
Here's an example of how to load the example from the official spec page:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
# This is a TOML document. Boom.
|
||||||
|
|
||||||
|
title = "TOML Example"
|
||||||
|
|
||||||
|
[owner]
|
||||||
|
name = "Tom Preston-Werner"
|
||||||
|
organization = "GitHub"
|
||||||
|
bio = "GitHub Cofounder & CEO\nLikes tater tots and beer."
|
||||||
|
dob = 1979-05-27T07:32:00Z # First class dates? Why not?
|
||||||
|
|
||||||
|
[database]
|
||||||
|
server = "192.168.1.1"
|
||||||
|
ports = [ 8001, 8001, 8002 ]
|
||||||
|
connection_max = 5000
|
||||||
|
enabled = true
|
||||||
|
|
||||||
|
[servers]
|
||||||
|
|
||||||
|
# You can indent as you please. Tabs or spaces. TOML don't care.
|
||||||
|
[servers.alpha]
|
||||||
|
ip = "10.0.0.1"
|
||||||
|
dc = "eqdc10"
|
||||||
|
|
||||||
|
[servers.beta]
|
||||||
|
ip = "10.0.0.2"
|
||||||
|
dc = "eqdc10"
|
||||||
|
|
||||||
|
[clients]
|
||||||
|
data = [ ["gamma", "delta"], [1, 2] ] # just an update to make sure parsers support it
|
||||||
|
|
||||||
|
# Line breaks are OK when inside arrays
|
||||||
|
hosts = [
|
||||||
|
"alpha",
|
||||||
|
"omega"
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
And the corresponding Go types are:
|
||||||
|
|
||||||
|
```go
|
||||||
|
type tomlConfig struct {
|
||||||
|
Title string
|
||||||
|
Owner ownerInfo
|
||||||
|
DB database `toml:"database"`
|
||||||
|
Servers map[string]server
|
||||||
|
Clients clients
|
||||||
|
}
|
||||||
|
|
||||||
|
type ownerInfo struct {
|
||||||
|
Name string
|
||||||
|
Org string `toml:"organization"`
|
||||||
|
Bio string
|
||||||
|
DOB time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type database struct {
|
||||||
|
Server string
|
||||||
|
Ports []int
|
||||||
|
ConnMax int `toml:"connection_max"`
|
||||||
|
Enabled bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type server struct {
|
||||||
|
IP string
|
||||||
|
DC string
|
||||||
|
}
|
||||||
|
|
||||||
|
type clients struct {
|
||||||
|
Data [][]interface{}
|
||||||
|
Hosts []string
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Note that a case insensitive match will be tried if an exact match can't be
|
||||||
|
found.
|
||||||
|
|
||||||
|
A working example of the above can be found in `_examples/example.{go,toml}`.
|
@ -0,0 +1,509 @@
|
|||||||
|
package toml
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"math"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func e(format string, args ...interface{}) error {
|
||||||
|
return fmt.Errorf("toml: "+format, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshaler is the interface implemented by objects that can unmarshal a
|
||||||
|
// TOML description of themselves.
|
||||||
|
type Unmarshaler interface {
|
||||||
|
UnmarshalTOML(interface{}) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal decodes the contents of `p` in TOML format into a pointer `v`.
|
||||||
|
func Unmarshal(p []byte, v interface{}) error {
|
||||||
|
_, err := Decode(string(p), v)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Primitive is a TOML value that hasn't been decoded into a Go value.
|
||||||
|
// When using the various `Decode*` functions, the type `Primitive` may
|
||||||
|
// be given to any value, and its decoding will be delayed.
|
||||||
|
//
|
||||||
|
// A `Primitive` value can be decoded using the `PrimitiveDecode` function.
|
||||||
|
//
|
||||||
|
// The underlying representation of a `Primitive` value is subject to change.
|
||||||
|
// Do not rely on it.
|
||||||
|
//
|
||||||
|
// N.B. Primitive values are still parsed, so using them will only avoid
|
||||||
|
// the overhead of reflection. They can be useful when you don't know the
|
||||||
|
// exact type of TOML data until run time.
|
||||||
|
type Primitive struct {
|
||||||
|
undecoded interface{}
|
||||||
|
context Key
|
||||||
|
}
|
||||||
|
|
||||||
|
// DEPRECATED!
|
||||||
|
//
|
||||||
|
// Use MetaData.PrimitiveDecode instead.
|
||||||
|
func PrimitiveDecode(primValue Primitive, v interface{}) error {
|
||||||
|
md := MetaData{decoded: make(map[string]bool)}
|
||||||
|
return md.unify(primValue.undecoded, rvalue(v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrimitiveDecode is just like the other `Decode*` functions, except it
|
||||||
|
// decodes a TOML value that has already been parsed. Valid primitive values
|
||||||
|
// can *only* be obtained from values filled by the decoder functions,
|
||||||
|
// including this method. (i.e., `v` may contain more `Primitive`
|
||||||
|
// values.)
|
||||||
|
//
|
||||||
|
// Meta data for primitive values is included in the meta data returned by
|
||||||
|
// the `Decode*` functions with one exception: keys returned by the Undecoded
|
||||||
|
// method will only reflect keys that were decoded. Namely, any keys hidden
|
||||||
|
// behind a Primitive will be considered undecoded. Executing this method will
|
||||||
|
// update the undecoded keys in the meta data. (See the example.)
|
||||||
|
func (md *MetaData) PrimitiveDecode(primValue Primitive, v interface{}) error {
|
||||||
|
md.context = primValue.context
|
||||||
|
defer func() { md.context = nil }()
|
||||||
|
return md.unify(primValue.undecoded, rvalue(v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode will decode the contents of `data` in TOML format into a pointer
|
||||||
|
// `v`.
|
||||||
|
//
|
||||||
|
// TOML hashes correspond to Go structs or maps. (Dealer's choice. They can be
|
||||||
|
// used interchangeably.)
|
||||||
|
//
|
||||||
|
// TOML arrays of tables correspond to either a slice of structs or a slice
|
||||||
|
// of maps.
|
||||||
|
//
|
||||||
|
// TOML datetimes correspond to Go `time.Time` values.
|
||||||
|
//
|
||||||
|
// All other TOML types (float, string, int, bool and array) correspond
|
||||||
|
// to the obvious Go types.
|
||||||
|
//
|
||||||
|
// An exception to the above rules is if a type implements the
|
||||||
|
// encoding.TextUnmarshaler interface. In this case, any primitive TOML value
|
||||||
|
// (floats, strings, integers, booleans and datetimes) will be converted to
|
||||||
|
// a byte string and given to the value's UnmarshalText method. See the
|
||||||
|
// Unmarshaler example for a demonstration with time duration strings.
|
||||||
|
//
|
||||||
|
// Key mapping
|
||||||
|
//
|
||||||
|
// TOML keys can map to either keys in a Go map or field names in a Go
|
||||||
|
// struct. The special `toml` struct tag may be used to map TOML keys to
|
||||||
|
// struct fields that don't match the key name exactly. (See the example.)
|
||||||
|
// A case insensitive match to struct names will be tried if an exact match
|
||||||
|
// can't be found.
|
||||||
|
//
|
||||||
|
// The mapping between TOML values and Go values is loose. That is, there
|
||||||
|
// may exist TOML values that cannot be placed into your representation, and
|
||||||
|
// there may be parts of your representation that do not correspond to
|
||||||
|
// TOML values. This loose mapping can be made stricter by using the IsDefined
|
||||||
|
// and/or Undecoded methods on the MetaData returned.
|
||||||
|
//
|
||||||
|
// This decoder will not handle cyclic types. If a cyclic type is passed,
|
||||||
|
// `Decode` will not terminate.
|
||||||
|
func Decode(data string, v interface{}) (MetaData, error) {
|
||||||
|
rv := reflect.ValueOf(v)
|
||||||
|
if rv.Kind() != reflect.Ptr {
|
||||||
|
return MetaData{}, e("Decode of non-pointer %s", reflect.TypeOf(v))
|
||||||
|
}
|
||||||
|
if rv.IsNil() {
|
||||||
|
return MetaData{}, e("Decode of nil %s", reflect.TypeOf(v))
|
||||||
|
}
|
||||||
|
p, err := parse(data)
|
||||||
|
if err != nil {
|
||||||
|
return MetaData{}, err
|
||||||
|
}
|
||||||
|
md := MetaData{
|
||||||
|
p.mapping, p.types, p.ordered,
|
||||||
|
make(map[string]bool, len(p.ordered)), nil,
|
||||||
|
}
|
||||||
|
return md, md.unify(p.mapping, indirect(rv))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeFile is just like Decode, except it will automatically read the
|
||||||
|
// contents of the file at `fpath` and decode it for you.
|
||||||
|
func DecodeFile(fpath string, v interface{}) (MetaData, error) {
|
||||||
|
bs, err := ioutil.ReadFile(fpath)
|
||||||
|
if err != nil {
|
||||||
|
return MetaData{}, err
|
||||||
|
}
|
||||||
|
return Decode(string(bs), v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeReader is just like Decode, except it will consume all bytes
|
||||||
|
// from the reader and decode it for you.
|
||||||
|
func DecodeReader(r io.Reader, v interface{}) (MetaData, error) {
|
||||||
|
bs, err := ioutil.ReadAll(r)
|
||||||
|
if err != nil {
|
||||||
|
return MetaData{}, err
|
||||||
|
}
|
||||||
|
return Decode(string(bs), v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// unify performs a sort of type unification based on the structure of `rv`,
// which is the client representation.
//
// Any type mismatch produces an error. Finding a type that we don't know
// how to handle produces an unsupported type error.
//
// The special cases below are order-sensitive: Primitive, Unmarshaler,
// time.Time and TextUnmarshaler must each be checked before dispatching on
// the reflect.Kind.
func (md *MetaData) unify(data interface{}, rv reflect.Value) error {

	// Special case. Look for a `Primitive` value.
	if rv.Type() == reflect.TypeOf((*Primitive)(nil)).Elem() {
		// Save the undecoded data and the key context into the primitive
		// value. The context must be copied because md.context is mutated
		// as decoding proceeds.
		context := make(Key, len(md.context))
		copy(context, md.context)
		rv.Set(reflect.ValueOf(Primitive{
			undecoded: data,
			context:   context,
		}))
		return nil
	}

	// Special case. Unmarshaler Interface support.
	if rv.CanAddr() {
		if v, ok := rv.Addr().Interface().(Unmarshaler); ok {
			return v.UnmarshalTOML(data)
		}
	}

	// Special case. Handle time.Time values specifically.
	// TODO: Remove this code when we decide to drop support for Go 1.1.
	// This isn't necessary in Go 1.2 because time.Time satisfies the encoding
	// interfaces.
	if rv.Type().AssignableTo(rvalue(time.Time{}).Type()) {
		return md.unifyDatetime(data, rv)
	}

	// Special case. Look for a value satisfying the TextUnmarshaler interface.
	if v, ok := rv.Interface().(TextUnmarshaler); ok {
		return md.unifyText(data, v)
	}
	// BUG(burntsushi)
	// The behavior here is incorrect whenever a Go type satisfies the
	// encoding.TextUnmarshaler interface but also corresponds to a TOML
	// hash or array. In particular, the unmarshaler should only be applied
	// to primitive TOML values. But at this point, it will be applied to
	// all kinds of values and produce an incorrect error whenever those values
	// are hashes or arrays (including arrays of tables).

	k := rv.Kind()

	// laziness: every integer kind (signed and unsigned) is contiguous in
	// the reflect.Kind enumeration, so one range check dispatches them all.
	if k >= reflect.Int && k <= reflect.Uint64 {
		return md.unifyInt(data, rv)
	}
	switch k {
	case reflect.Ptr:
		// Allocate a fresh element, unify into it, then point rv at it.
		elem := reflect.New(rv.Type().Elem())
		err := md.unify(data, reflect.Indirect(elem))
		if err != nil {
			return err
		}
		rv.Set(elem)
		return nil
	case reflect.Struct:
		return md.unifyStruct(data, rv)
	case reflect.Map:
		return md.unifyMap(data, rv)
	case reflect.Array:
		return md.unifyArray(data, rv)
	case reflect.Slice:
		return md.unifySlice(data, rv)
	case reflect.String:
		return md.unifyString(data, rv)
	case reflect.Bool:
		return md.unifyBool(data, rv)
	case reflect.Interface:
		// we only support empty interfaces.
		if rv.NumMethod() > 0 {
			return e("unsupported type %s", rv.Type())
		}
		return md.unifyAnything(data, rv)
	case reflect.Float32:
		fallthrough
	case reflect.Float64:
		return md.unifyFloat64(data, rv)
	}
	return e("unsupported type %s", rv.Kind())
}
|
||||||
|
|
||||||
|
// unifyStruct decodes a TOML table (mapping) into the struct rv.
// Field lookup prefers an exact name match; a case-insensitive match is
// used only as a fallback. Keys that match no field are silently ignored.
func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
	tmap, ok := mapping.(map[string]interface{})
	if !ok {
		if mapping == nil {
			// A nil TOML value unifies with anything.
			return nil
		}
		return e("type mismatch for %s: expected table but found %T",
			rv.Type().String(), mapping)
	}

	for key, datum := range tmap {
		var f *field
		fields := cachedTypeFields(rv.Type())
		for i := range fields {
			ff := &fields[i]
			if ff.name == key {
				// Exact match wins immediately.
				f = ff
				break
			}
			// Fallback: first case-insensitive match, kept only until an
			// exact match (if any) is found later in the list.
			if f == nil && strings.EqualFold(ff.name, key) {
				f = ff
			}
		}
		if f != nil {
			subv := rv
			// Walk embedded-struct indices down to the actual field.
			for _, i := range f.index {
				subv = indirect(subv.Field(i))
			}
			if isUnifiable(subv) {
				// Mark the key as decoded and push it onto the context so
				// nested errors/bookkeeping carry the full key path.
				md.decoded[md.context.add(key).String()] = true
				md.context = append(md.context, key)
				if err := md.unify(datum, subv); err != nil {
					return err
				}
				md.context = md.context[0 : len(md.context)-1]
			} else if f.name != "" {
				// Bad user! No soup for you!
				return e("cannot write unexported field %s.%s",
					rv.Type().String(), f.name)
			}
		}
	}
	return nil
}
|
||||||
|
|
||||||
|
func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error {
|
||||||
|
tmap, ok := mapping.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
if tmap == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return badtype("map", mapping)
|
||||||
|
}
|
||||||
|
if rv.IsNil() {
|
||||||
|
rv.Set(reflect.MakeMap(rv.Type()))
|
||||||
|
}
|
||||||
|
for k, v := range tmap {
|
||||||
|
md.decoded[md.context.add(k).String()] = true
|
||||||
|
md.context = append(md.context, k)
|
||||||
|
|
||||||
|
rvkey := indirect(reflect.New(rv.Type().Key()))
|
||||||
|
rvval := reflect.Indirect(reflect.New(rv.Type().Elem()))
|
||||||
|
if err := md.unify(v, rvval); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
md.context = md.context[0 : len(md.context)-1]
|
||||||
|
|
||||||
|
rvkey.SetString(k)
|
||||||
|
rv.SetMapIndex(rvkey, rvval)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (md *MetaData) unifyArray(data interface{}, rv reflect.Value) error {
|
||||||
|
datav := reflect.ValueOf(data)
|
||||||
|
if datav.Kind() != reflect.Slice {
|
||||||
|
if !datav.IsValid() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return badtype("slice", data)
|
||||||
|
}
|
||||||
|
sliceLen := datav.Len()
|
||||||
|
if sliceLen != rv.Len() {
|
||||||
|
return e("expected array length %d; got TOML array of length %d",
|
||||||
|
rv.Len(), sliceLen)
|
||||||
|
}
|
||||||
|
return md.unifySliceArray(datav, rv)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (md *MetaData) unifySlice(data interface{}, rv reflect.Value) error {
|
||||||
|
datav := reflect.ValueOf(data)
|
||||||
|
if datav.Kind() != reflect.Slice {
|
||||||
|
if !datav.IsValid() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return badtype("slice", data)
|
||||||
|
}
|
||||||
|
n := datav.Len()
|
||||||
|
if rv.IsNil() || rv.Cap() < n {
|
||||||
|
rv.Set(reflect.MakeSlice(rv.Type(), n, n))
|
||||||
|
}
|
||||||
|
rv.SetLen(n)
|
||||||
|
return md.unifySliceArray(datav, rv)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (md *MetaData) unifySliceArray(data, rv reflect.Value) error {
|
||||||
|
sliceLen := data.Len()
|
||||||
|
for i := 0; i < sliceLen; i++ {
|
||||||
|
v := data.Index(i).Interface()
|
||||||
|
sliceval := indirect(rv.Index(i))
|
||||||
|
if err := md.unify(v, sliceval); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (md *MetaData) unifyDatetime(data interface{}, rv reflect.Value) error {
|
||||||
|
if _, ok := data.(time.Time); ok {
|
||||||
|
rv.Set(reflect.ValueOf(data))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return badtype("time.Time", data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (md *MetaData) unifyString(data interface{}, rv reflect.Value) error {
|
||||||
|
if s, ok := data.(string); ok {
|
||||||
|
rv.SetString(s)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return badtype("string", data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error {
|
||||||
|
if num, ok := data.(float64); ok {
|
||||||
|
switch rv.Kind() {
|
||||||
|
case reflect.Float32:
|
||||||
|
fallthrough
|
||||||
|
case reflect.Float64:
|
||||||
|
rv.SetFloat(num)
|
||||||
|
default:
|
||||||
|
panic("bug")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return badtype("float", data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// unifyInt decodes a TOML integer (always carried as int64) into any Go
// integer kind, bounds-checking the narrower kinds. The caller (unify)
// guarantees rv.Kind() is in the contiguous Int..Uint64 range.
func (md *MetaData) unifyInt(data interface{}, rv reflect.Value) error {
	if num, ok := data.(int64); ok {
		if rv.Kind() >= reflect.Int && rv.Kind() <= reflect.Int64 {
			switch rv.Kind() {
			case reflect.Int, reflect.Int64:
				// No bounds checking necessary.
			case reflect.Int8:
				if num < math.MinInt8 || num > math.MaxInt8 {
					return e("value %d is out of range for int8", num)
				}
			case reflect.Int16:
				if num < math.MinInt16 || num > math.MaxInt16 {
					return e("value %d is out of range for int16", num)
				}
			case reflect.Int32:
				if num < math.MinInt32 || num > math.MaxInt32 {
					return e("value %d is out of range for int32", num)
				}
			}
			rv.SetInt(num)
		} else if rv.Kind() >= reflect.Uint && rv.Kind() <= reflect.Uint64 {
			unum := uint64(num)
			switch rv.Kind() {
			case reflect.Uint, reflect.Uint64:
				// No bounds checking necessary.
				// NOTE(review): a negative num wraps silently for
				// Uint/Uint64 (no num < 0 check here) — confirm intended.
			case reflect.Uint8:
				if num < 0 || unum > math.MaxUint8 {
					return e("value %d is out of range for uint8", num)
				}
			case reflect.Uint16:
				if num < 0 || unum > math.MaxUint16 {
					return e("value %d is out of range for uint16", num)
				}
			case reflect.Uint32:
				if num < 0 || unum > math.MaxUint32 {
					return e("value %d is out of range for uint32", num)
				}
			}
			rv.SetUint(unum)
		} else {
			// unify's range check makes any other kind impossible.
			panic("unreachable")
		}
		return nil
	}
	return badtype("integer", data)
}
|
||||||
|
|
||||||
|
func (md *MetaData) unifyBool(data interface{}, rv reflect.Value) error {
|
||||||
|
if b, ok := data.(bool); ok {
|
||||||
|
rv.SetBool(b)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return badtype("boolean", data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// unifyAnything assigns the decoded value into an empty-interface
// destination verbatim, with no type checking.
func (md *MetaData) unifyAnything(data interface{}, rv reflect.Value) error {
	rv.Set(reflect.ValueOf(data))
	return nil
}
|
||||||
|
|
||||||
|
// unifyText decodes a primitive TOML value into v via UnmarshalText,
// first rendering the value as a string. Non-primitive values (tables,
// arrays) fall into the default case and produce a badtype error.
func (md *MetaData) unifyText(data interface{}, v TextUnmarshaler) error {
	var s string
	switch sdata := data.(type) {
	case TextMarshaler:
		text, err := sdata.MarshalText()
		if err != nil {
			return err
		}
		s = string(text)
	case fmt.Stringer:
		s = sdata.String()
	case string:
		s = sdata
	case bool:
		s = fmt.Sprintf("%v", sdata)
	case int64:
		s = fmt.Sprintf("%d", sdata)
	case float64:
		// NOTE(review): %f fixes six decimal places and can lose precision
		// for very large/small floats — confirm before relying on
		// round-tripping floats through text unmarshaling.
		s = fmt.Sprintf("%f", sdata)
	default:
		return badtype("primitive (string-like)", data)
	}
	if err := v.UnmarshalText([]byte(s)); err != nil {
		return err
	}
	return nil
}
|
||||||
|
|
||||||
|
// rvalue returns a reflect.Value of `v`. All pointers are resolved,
// allocating new values for any nil pointers along the way (see indirect).
func rvalue(v interface{}) reflect.Value {
	return indirect(reflect.ValueOf(v))
}
|
||||||
|
|
||||||
|
// indirect returns the value pointed to by a pointer.
// Pointers are followed until the value is not a pointer.
// New values are allocated for each nil pointer.
//
// An exception to this rule is if the value satisfies an interface of
// interest to us (like encoding.TextUnmarshaler): then the *address* is
// returned so the interface method set is preserved.
func indirect(v reflect.Value) reflect.Value {
	if v.Kind() != reflect.Ptr {
		// Non-pointer: prefer the addressable form if it implements
		// TextUnmarshaler, since that interface needs a pointer receiver.
		if v.CanSet() {
			pv := v.Addr()
			if _, ok := pv.Interface().(TextUnmarshaler); ok {
				return pv
			}
		}
		return v
	}
	// Allocate through nil pointers so the chain is always settable.
	if v.IsNil() {
		v.Set(reflect.New(v.Type().Elem()))
	}
	// Recurse until the pointer chain is fully resolved.
	return indirect(reflect.Indirect(v))
}
|
||||||
|
|
||||||
|
func isUnifiable(rv reflect.Value) bool {
|
||||||
|
if rv.CanSet() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if _, ok := rv.Interface().(TextUnmarshaler); ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// badtype builds the standard "cannot load X into Y" type-mismatch error,
// naming the offending Go value's dynamic type and the expected TOML kind.
func badtype(expected string, data interface{}) error {
	return e("cannot load TOML value of type %T into a Go %s", data, expected)
}
|
@ -0,0 +1,121 @@
|
|||||||
|
package toml
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// MetaData allows access to meta information about TOML data that may not
// be inferrable via reflection. In particular, whether a key has been defined
// and the TOML type of a key.
type MetaData struct {
	mapping map[string]interface{} // raw decoded key/value tree
	types   map[string]tomlType    // TOML type per dotted key
	keys    []Key                  // all keys, in document order
	decoded map[string]bool        // dotted keys successfully decoded
	context Key                    // Used only during decoding.
}
|
||||||
|
|
||||||
|
// IsDefined returns true if the key given exists in the TOML data. The key
|
||||||
|
// should be specified hierarchially. e.g.,
|
||||||
|
//
|
||||||
|
// // access the TOML key 'a.b.c'
|
||||||
|
// IsDefined("a", "b", "c")
|
||||||
|
//
|
||||||
|
// IsDefined will return false if an empty key given. Keys are case sensitive.
|
||||||
|
func (md *MetaData) IsDefined(key ...string) bool {
|
||||||
|
if len(key) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
var hash map[string]interface{}
|
||||||
|
var ok bool
|
||||||
|
var hashOrVal interface{} = md.mapping
|
||||||
|
for _, k := range key {
|
||||||
|
if hash, ok = hashOrVal.(map[string]interface{}); !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if hashOrVal, ok = hash[k]; !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type returns a string representation of the type of the key specified.
|
||||||
|
//
|
||||||
|
// Type will return the empty string if given an empty key or a key that
|
||||||
|
// does not exist. Keys are case sensitive.
|
||||||
|
func (md *MetaData) Type(key ...string) string {
|
||||||
|
fullkey := strings.Join(key, ".")
|
||||||
|
if typ, ok := md.types[fullkey]; ok {
|
||||||
|
return typ.typeString()
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Key is the type of any TOML key, including key groups. Each element is
// one dotted-path component, from most general to most specific.
// Use (MetaData).Keys to get values of this type.
type Key []string
|
||||||
|
|
||||||
|
// String renders the key as its dotted TOML path (components joined by ".").
func (k Key) String() string {
	return strings.Join(k, ".")
}
|
||||||
|
|
||||||
|
func (k Key) maybeQuotedAll() string {
|
||||||
|
var ss []string
|
||||||
|
for i := range k {
|
||||||
|
ss = append(ss, k.maybeQuoted(i))
|
||||||
|
}
|
||||||
|
return strings.Join(ss, ".")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k Key) maybeQuoted(i int) string {
|
||||||
|
quote := false
|
||||||
|
for _, c := range k[i] {
|
||||||
|
if !isBareKeyChar(c) {
|
||||||
|
quote = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if quote {
|
||||||
|
return "\"" + strings.Replace(k[i], "\"", "\\\"", -1) + "\""
|
||||||
|
}
|
||||||
|
return k[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k Key) add(piece string) Key {
|
||||||
|
newKey := make(Key, len(k)+1)
|
||||||
|
copy(newKey, k)
|
||||||
|
newKey[len(k)] = piece
|
||||||
|
return newKey
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keys returns a slice of every key in the TOML data, including key groups.
// Each key is itself a slice, where the first element is the top of the
// hierarchy and the last is the most specific.
//
// The list will have the same order as the keys appeared in the TOML data.
//
// All keys returned are non-empty.
//
// The returned slice is the internal one; callers should not mutate it.
func (md *MetaData) Keys() []Key {
	return md.keys
}
|
||||||
|
|
||||||
|
// Undecoded returns all keys that have not been decoded in the order in which
|
||||||
|
// they appear in the original TOML document.
|
||||||
|
//
|
||||||
|
// This includes keys that haven't been decoded because of a Primitive value.
|
||||||
|
// Once the Primitive value is decoded, the keys will be considered decoded.
|
||||||
|
//
|
||||||
|
// Also note that decoding into an empty interface will result in no decoding,
|
||||||
|
// and so no keys will be considered decoded.
|
||||||
|
//
|
||||||
|
// In this sense, the Undecoded keys correspond to keys in the TOML document
|
||||||
|
// that do not have a concrete type in your representation.
|
||||||
|
func (md *MetaData) Undecoded() []Key {
|
||||||
|
undecoded := make([]Key, 0, len(md.keys))
|
||||||
|
for _, key := range md.keys {
|
||||||
|
if !md.decoded[key.String()] {
|
||||||
|
undecoded = append(undecoded, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return undecoded
|
||||||
|
}
|
@ -0,0 +1,27 @@
|
|||||||
|
/*
|
||||||
|
Package toml provides facilities for decoding and encoding TOML configuration
|
||||||
|
files via reflection. There is also support for delaying decoding with
|
||||||
|
the Primitive type, and querying the set of keys in a TOML document with the
|
||||||
|
MetaData type.
|
||||||
|
|
||||||
|
The specification implemented: https://github.com/toml-lang/toml
|
||||||
|
|
||||||
|
The sub-command github.com/BurntSushi/toml/cmd/tomlv can be used to verify
|
||||||
|
whether a file is a valid TOML document. It can also be used to print the
|
||||||
|
type of each key in a TOML document.
|
||||||
|
|
||||||
|
Testing
|
||||||
|
|
||||||
|
There are two important types of tests used for this package. The first is
|
||||||
|
contained inside '*_test.go' files and uses the standard Go unit testing
|
||||||
|
framework. These tests are primarily devoted to holistically testing the
|
||||||
|
decoder and encoder.
|
||||||
|
|
||||||
|
The second type of testing is used to verify the implementation's adherence
|
||||||
|
to the TOML specification. These tests have been factored into their own
|
||||||
|
project: https://github.com/BurntSushi/toml-test
|
||||||
|
|
||||||
|
The reason the tests are in a separate project is so that they can be used by
|
||||||
|
any implementation of TOML. Namely, it is language agnostic.
|
||||||
|
*/
|
||||||
|
package toml
|
@ -0,0 +1,568 @@
|
|||||||
|
package toml
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"reflect"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// tomlEncodeError wraps an error raised via panic during encoding so that
// safeEncode's recover can distinguish it from genuine runtime panics.
type tomlEncodeError struct{ error }
|
||||||
|
|
||||||
|
// Sentinel errors returned (via encPanic) for Go values that have no valid
// TOML representation.
var (
	errArrayMixedElementTypes = errors.New(
		"toml: cannot encode array with mixed element types")
	errArrayNilElement = errors.New(
		"toml: cannot encode array with nil element")
	errNonString = errors.New(
		"toml: cannot encode a map with non-string key type")
	errAnonNonStruct = errors.New(
		"toml: cannot encode an anonymous field that is not a struct")
	errArrayNoTable = errors.New(
		"toml: TOML array element cannot contain a table")
	errNoKey = errors.New(
		"toml: top-level values must be Go maps or structs")
	errAnything = errors.New("") // used in testing
)
|
||||||
|
|
||||||
|
// quotedReplacer escapes the characters that must be backslash-escaped
// inside a basic (double-quoted) TOML string.
var quotedReplacer = strings.NewReplacer(
	"\t", "\\t",
	"\n", "\\n",
	"\r", "\\r",
	"\"", "\\\"",
	"\\", "\\\\",
)
|
||||||
|
|
||||||
|
// Encoder controls the encoding of Go values to a TOML document to some
// io.Writer.
//
// The indentation level can be controlled with the Indent field.
type Encoder struct {
	// A single indentation level. By default it is two spaces.
	Indent string

	// hasWritten is whether we have written any output to w yet.
	// It controls whether newline() emits a separating blank line.
	hasWritten bool
	w          *bufio.Writer
}
|
||||||
|
|
||||||
|
// NewEncoder returns a TOML encoder that encodes Go values to the io.Writer
|
||||||
|
// given. By default, a single indentation level is 2 spaces.
|
||||||
|
func NewEncoder(w io.Writer) *Encoder {
|
||||||
|
return &Encoder{
|
||||||
|
w: bufio.NewWriter(w),
|
||||||
|
Indent: " ",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode writes a TOML representation of the Go value to the underlying
|
||||||
|
// io.Writer. If the value given cannot be encoded to a valid TOML document,
|
||||||
|
// then an error is returned.
|
||||||
|
//
|
||||||
|
// The mapping between Go values and TOML values should be precisely the same
|
||||||
|
// as for the Decode* functions. Similarly, the TextMarshaler interface is
|
||||||
|
// supported by encoding the resulting bytes as strings. (If you want to write
|
||||||
|
// arbitrary binary data then you will need to use something like base64 since
|
||||||
|
// TOML does not have any binary types.)
|
||||||
|
//
|
||||||
|
// When encoding TOML hashes (i.e., Go maps or structs), keys without any
|
||||||
|
// sub-hashes are encoded first.
|
||||||
|
//
|
||||||
|
// If a Go map is encoded, then its keys are sorted alphabetically for
|
||||||
|
// deterministic output. More control over this behavior may be provided if
|
||||||
|
// there is demand for it.
|
||||||
|
//
|
||||||
|
// Encoding Go values without a corresponding TOML representation---like map
|
||||||
|
// types with non-string keys---will cause an error to be returned. Similarly
|
||||||
|
// for mixed arrays/slices, arrays/slices with nil elements, embedded
|
||||||
|
// non-struct types and nested slices containing maps or structs.
|
||||||
|
// (e.g., [][]map[string]string is not allowed but []map[string]string is OK
|
||||||
|
// and so is []map[string][]string.)
|
||||||
|
func (enc *Encoder) Encode(v interface{}) error {
|
||||||
|
rv := eindirect(reflect.ValueOf(v))
|
||||||
|
if err := enc.safeEncode(Key([]string{}), rv); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return enc.w.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
// safeEncode runs enc.encode and converts any tomlEncodeError panic back
// into an ordinary error return. Panics of any other type are re-raised,
// since they indicate a genuine bug rather than invalid input.
func (enc *Encoder) safeEncode(key Key, rv reflect.Value) (err error) {
	defer func() {
		if r := recover(); r != nil {
			if terr, ok := r.(tomlEncodeError); ok {
				// Encoding error raised via encPanic: surface it normally.
				err = terr.error
				return
			}
			panic(r)
		}
	}()
	enc.encode(key, rv)
	return nil
}
|
||||||
|
|
||||||
|
// encode dispatches a single Go value at the given key to the appropriate
// emitter: "key = value" lines for primitives, table syntax for hashes,
// and array-of-tables syntax where required. Failures are raised via
// encPanic and recovered in safeEncode.
func (enc *Encoder) encode(key Key, rv reflect.Value) {
	// Special case. Time needs to be in ISO8601 format.
	// Special case. If we can marshal the type to text, then we used that.
	// Basically, this prevents the encoder for handling these types as
	// generic structs (or whatever the underlying type of a TextMarshaler is).
	switch rv.Interface().(type) {
	case time.Time, TextMarshaler:
		enc.keyEqElement(key, rv)
		return
	}

	k := rv.Kind()
	switch k {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
		reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
		reflect.Uint64,
		reflect.Float32, reflect.Float64, reflect.String, reflect.Bool:
		enc.keyEqElement(key, rv)
	case reflect.Array, reflect.Slice:
		// An array whose elements are tables gets [[key]] syntax; any
		// other array is an inline "key = [...]" value.
		if typeEqual(tomlArrayHash, tomlTypeOfGo(rv)) {
			enc.eArrayOfTables(key, rv)
		} else {
			enc.keyEqElement(key, rv)
		}
	case reflect.Interface:
		if rv.IsNil() {
			return
		}
		enc.encode(key, rv.Elem())
	case reflect.Map:
		if rv.IsNil() {
			return
		}
		enc.eTable(key, rv)
	case reflect.Ptr:
		if rv.IsNil() {
			return
		}
		enc.encode(key, rv.Elem())
	case reflect.Struct:
		enc.eTable(key, rv)
	default:
		panic(e("unsupported type for key '%s': %s", key, k))
	}
}
|
||||||
|
|
||||||
|
// eElement encodes any value that can be an array element (primitives and
// arrays).
func (enc *Encoder) eElement(rv reflect.Value) {
	switch v := rv.Interface().(type) {
	case time.Time:
		// Special case time.Time as a primitive. Has to come before
		// TextMarshaler below because time.Time implements
		// encoding.TextMarshaler, but we need to always use UTC.
		enc.wf(v.UTC().Format("2006-01-02T15:04:05Z"))
		return
	case TextMarshaler:
		// Special case. Use text marshaler if it's available for this value.
		if s, err := v.MarshalText(); err != nil {
			encPanic(err)
		} else {
			enc.writeQuoted(string(s))
		}
		return
	}
	switch rv.Kind() {
	case reflect.Bool:
		enc.wf(strconv.FormatBool(rv.Bool()))
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
		reflect.Int64:
		enc.wf(strconv.FormatInt(rv.Int(), 10))
	case reflect.Uint, reflect.Uint8, reflect.Uint16,
		reflect.Uint32, reflect.Uint64:
		enc.wf(strconv.FormatUint(rv.Uint(), 10))
	case reflect.Float32:
		enc.wf(floatAddDecimal(strconv.FormatFloat(rv.Float(), 'f', -1, 32)))
	case reflect.Float64:
		enc.wf(floatAddDecimal(strconv.FormatFloat(rv.Float(), 'f', -1, 64)))
	case reflect.Array, reflect.Slice:
		enc.eArrayOrSliceElement(rv)
	case reflect.Interface:
		// Unwrap the interface and encode the dynamic value.
		enc.eElement(rv.Elem())
	case reflect.String:
		enc.writeQuoted(rv.String())
	default:
		panic(e("unexpected primitive type: %s", rv.Kind()))
	}
}
|
||||||
|
|
||||||
|
// floatAddDecimal appends ".0" to a formatted float that has no decimal
// point: by the TOML spec, all floats must have a decimal with at least one
// number on either side.
func floatAddDecimal(fstr string) string {
	if strings.Contains(fstr, ".") {
		return fstr
	}
	return fstr + ".0"
}
|
||||||
|
|
||||||
|
// writeQuoted emits s as a basic double-quoted TOML string, escaping
// control characters, quotes and backslashes via quotedReplacer.
func (enc *Encoder) writeQuoted(s string) {
	enc.wf("\"%s\"", quotedReplacer.Replace(s))
}
|
||||||
|
|
||||||
|
func (enc *Encoder) eArrayOrSliceElement(rv reflect.Value) {
|
||||||
|
length := rv.Len()
|
||||||
|
enc.wf("[")
|
||||||
|
for i := 0; i < length; i++ {
|
||||||
|
elem := rv.Index(i)
|
||||||
|
enc.eElement(elem)
|
||||||
|
if i != length-1 {
|
||||||
|
enc.wf(", ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
enc.wf("]")
|
||||||
|
}
|
||||||
|
|
||||||
|
// eArrayOfTables emits [[key]] syntax for each non-nil element of a slice
// or array of tables. A top-level (empty) key is invalid here.
func (enc *Encoder) eArrayOfTables(key Key, rv reflect.Value) {
	if len(key) == 0 {
		encPanic(errNoKey)
	}
	for i := 0; i < rv.Len(); i++ {
		trv := rv.Index(i)
		if isNil(trv) {
			// Nil elements are silently skipped.
			continue
		}
		panicIfInvalidKey(key)
		enc.newline()
		enc.wf("%s[[%s]]", enc.indentStr(key), key.maybeQuotedAll())
		enc.newline()
		enc.eMapOrStruct(key, trv)
	}
}
|
||||||
|
|
||||||
|
// eTable emits a [key] table header (when key is non-empty) followed by the
// table's contents. The top-level table (empty key) gets no header.
func (enc *Encoder) eTable(key Key, rv reflect.Value) {
	panicIfInvalidKey(key)
	if len(key) == 1 {
		// Output an extra newline between top-level tables.
		// (The newline isn't written if nothing else has been written though.)
		enc.newline()
	}
	if len(key) > 0 {
		enc.wf("%s[%s]", enc.indentStr(key), key.maybeQuotedAll())
		enc.newline()
	}
	enc.eMapOrStruct(key, rv)
}
|
||||||
|
|
||||||
|
func (enc *Encoder) eMapOrStruct(key Key, rv reflect.Value) {
|
||||||
|
switch rv := eindirect(rv); rv.Kind() {
|
||||||
|
case reflect.Map:
|
||||||
|
enc.eMap(key, rv)
|
||||||
|
case reflect.Struct:
|
||||||
|
enc.eStruct(key, rv)
|
||||||
|
default:
|
||||||
|
panic("eTable: unhandled reflect.Value Kind: " + rv.Kind().String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// eMap encodes a Go map as a TOML table. Keys must be strings. Output is
// deterministic: direct (non-table) values are written first, then
// sub-tables, each group sorted alphabetically.
func (enc *Encoder) eMap(key Key, rv reflect.Value) {
	rt := rv.Type()
	if rt.Key().Kind() != reflect.String {
		encPanic(errNonString)
	}

	// Sort keys so that we have deterministic output. And write keys directly
	// underneath this key first, before writing sub-structs or sub-maps.
	var mapKeysDirect, mapKeysSub []string
	for _, mapKey := range rv.MapKeys() {
		k := mapKey.String()
		if typeIsHash(tomlTypeOfGo(rv.MapIndex(mapKey))) {
			mapKeysSub = append(mapKeysSub, k)
		} else {
			mapKeysDirect = append(mapKeysDirect, k)
		}
	}

	var writeMapKeys = func(mapKeys []string) {
		sort.Strings(mapKeys)
		for _, mapKey := range mapKeys {
			mrv := rv.MapIndex(reflect.ValueOf(mapKey))
			if isNil(mrv) {
				// Don't write anything for nil fields.
				continue
			}
			enc.encode(key.add(mapKey), mrv)
		}
	}
	writeMapKeys(mapKeysDirect)
	writeMapKeys(mapKeysSub)
}
|
||||||
|
|
||||||
|
// eStruct encodes a Go struct as a TOML table. Fields holding direct
// (non-table) values are written before fields that become sub-tables, and
// anonymous embedded structs are flattened (unless they carry a toml tag
// name), matching encoding/json semantics.
func (enc *Encoder) eStruct(key Key, rv reflect.Value) {
	// Write keys for fields directly under this key first, because if we write
	// a field that creates a new table, then all keys under it will be in that
	// table (not the one we're writing here).
	rt := rv.Type()
	var fieldsDirect, fieldsSub [][]int
	// addFields recursively collects field index paths, descending into
	// untagged anonymous structs so their fields are promoted.
	var addFields func(rt reflect.Type, rv reflect.Value, start []int)
	addFields = func(rt reflect.Type, rv reflect.Value, start []int) {
		for i := 0; i < rt.NumField(); i++ {
			f := rt.Field(i)
			// skip unexported fields
			if f.PkgPath != "" && !f.Anonymous {
				continue
			}
			frv := rv.Field(i)
			if f.Anonymous {
				t := f.Type
				switch t.Kind() {
				case reflect.Struct:
					// Treat anonymous struct fields with
					// tag names as though they are not
					// anonymous, like encoding/json does.
					if getOptions(f.Tag).name == "" {
						addFields(t, frv, f.Index)
						continue
					}
				case reflect.Ptr:
					if t.Elem().Kind() == reflect.Struct &&
						getOptions(f.Tag).name == "" {
						if !frv.IsNil() {
							addFields(t.Elem(), frv.Elem(), f.Index)
						}
						continue
					}
					// Fall through to the normal field encoding logic below
					// for non-struct anonymous fields.
				}
			}

			if typeIsHash(tomlTypeOfGo(frv)) {
				fieldsSub = append(fieldsSub, append(start, f.Index...))
			} else {
				fieldsDirect = append(fieldsDirect, append(start, f.Index...))
			}
		}
	}
	addFields(rt, rv, nil)

	// writeFields emits one group of collected fields, honoring the toml
	// tag options (skip, rename, omitempty, omitzero).
	var writeFields = func(fields [][]int) {
		for _, fieldIndex := range fields {
			sft := rt.FieldByIndex(fieldIndex)
			sf := rv.FieldByIndex(fieldIndex)
			if isNil(sf) {
				// Don't write anything for nil fields.
				continue
			}

			opts := getOptions(sft.Tag)
			if opts.skip {
				continue
			}
			keyName := sft.Name
			if opts.name != "" {
				keyName = opts.name
			}
			if opts.omitempty && isEmpty(sf) {
				continue
			}
			if opts.omitzero && isZero(sf) {
				continue
			}

			enc.encode(key.add(keyName), sf)
		}
	}
	writeFields(fieldsDirect)
	writeFields(fieldsSub)
}
|
||||||
|
|
||||||
|
// tomlTypeOfGo returns the TOML type of a Go value. The type may be `nil`,
// which means no concrete TOML type could be found (e.g. a nil value).
// It is used, among other things, to determine whether the types of array
// elements are mixed (which is forbidden).
func tomlTypeOfGo(rv reflect.Value) tomlType {
	if isNil(rv) || !rv.IsValid() {
		return nil
	}
	switch rv.Kind() {
	case reflect.Bool:
		return tomlBool
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
		reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
		reflect.Uint64:
		return tomlInteger
	case reflect.Float32, reflect.Float64:
		return tomlFloat
	case reflect.Array, reflect.Slice:
		// An array of tables is its own TOML type, distinct from a plain
		// array.
		if typeEqual(tomlHash, tomlArrayType(rv)) {
			return tomlArrayHash
		}
		return tomlArray
	case reflect.Ptr, reflect.Interface:
		return tomlTypeOfGo(rv.Elem())
	case reflect.String:
		return tomlString
	case reflect.Map:
		return tomlHash
	case reflect.Struct:
		switch rv.Interface().(type) {
		case time.Time:
			return tomlDatetime
		case TextMarshaler:
			// TextMarshaler output is encoded as a quoted string.
			return tomlString
		default:
			return tomlHash
		}
	default:
		panic("unexpected reflect.Kind: " + rv.Kind().String())
	}
}
|
||||||
|
|
||||||
|
// tomlArrayType returns the element type of a TOML array. The type returned
|
||||||
|
// may be nil if it cannot be determined (e.g., a nil slice or a zero length
|
||||||
|
// slize). This function may also panic if it finds a type that cannot be
|
||||||
|
// expressed in TOML (such as nil elements, heterogeneous arrays or directly
|
||||||
|
// nested arrays of tables).
|
||||||
|
func tomlArrayType(rv reflect.Value) tomlType {
|
||||||
|
if isNil(rv) || !rv.IsValid() || rv.Len() == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
firstType := tomlTypeOfGo(rv.Index(0))
|
||||||
|
if firstType == nil {
|
||||||
|
encPanic(errArrayNilElement)
|
||||||
|
}
|
||||||
|
|
||||||
|
rvlen := rv.Len()
|
||||||
|
for i := 1; i < rvlen; i++ {
|
||||||
|
elem := rv.Index(i)
|
||||||
|
switch elemType := tomlTypeOfGo(elem); {
|
||||||
|
case elemType == nil:
|
||||||
|
encPanic(errArrayNilElement)
|
||||||
|
case !typeEqual(firstType, elemType):
|
||||||
|
encPanic(errArrayMixedElementTypes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If we have a nested array, then we must make sure that the nested
|
||||||
|
// array contains ONLY primitives.
|
||||||
|
// This checks arbitrarily nested arrays.
|
||||||
|
if typeEqual(firstType, tomlArray) || typeEqual(firstType, tomlArrayHash) {
|
||||||
|
nest := tomlArrayType(eindirect(rv.Index(0)))
|
||||||
|
if typeEqual(nest, tomlHash) || typeEqual(nest, tomlArrayHash) {
|
||||||
|
encPanic(errArrayNoTable)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return firstType
|
||||||
|
}
|
||||||
|
|
||||||
|
type tagOptions struct {
|
||||||
|
skip bool // "-"
|
||||||
|
name string
|
||||||
|
omitempty bool
|
||||||
|
omitzero bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func getOptions(tag reflect.StructTag) tagOptions {
|
||||||
|
t := tag.Get("toml")
|
||||||
|
if t == "-" {
|
||||||
|
return tagOptions{skip: true}
|
||||||
|
}
|
||||||
|
var opts tagOptions
|
||||||
|
parts := strings.Split(t, ",")
|
||||||
|
opts.name = parts[0]
|
||||||
|
for _, s := range parts[1:] {
|
||||||
|
switch s {
|
||||||
|
case "omitempty":
|
||||||
|
opts.omitempty = true
|
||||||
|
case "omitzero":
|
||||||
|
opts.omitzero = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return opts
|
||||||
|
}
|
||||||
|
|
||||||
|
func isZero(rv reflect.Value) bool {
|
||||||
|
switch rv.Kind() {
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
return rv.Int() == 0
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
return rv.Uint() == 0
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
return rv.Float() == 0.0
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func isEmpty(rv reflect.Value) bool {
|
||||||
|
switch rv.Kind() {
|
||||||
|
case reflect.Array, reflect.Slice, reflect.Map, reflect.String:
|
||||||
|
return rv.Len() == 0
|
||||||
|
case reflect.Bool:
|
||||||
|
return !rv.Bool()
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (enc *Encoder) newline() {
|
||||||
|
if enc.hasWritten {
|
||||||
|
enc.wf("\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (enc *Encoder) keyEqElement(key Key, val reflect.Value) {
|
||||||
|
if len(key) == 0 {
|
||||||
|
encPanic(errNoKey)
|
||||||
|
}
|
||||||
|
panicIfInvalidKey(key)
|
||||||
|
enc.wf("%s%s = ", enc.indentStr(key), key.maybeQuoted(len(key)-1))
|
||||||
|
enc.eElement(val)
|
||||||
|
enc.newline()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (enc *Encoder) wf(format string, v ...interface{}) {
|
||||||
|
if _, err := fmt.Fprintf(enc.w, format, v...); err != nil {
|
||||||
|
encPanic(err)
|
||||||
|
}
|
||||||
|
enc.hasWritten = true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (enc *Encoder) indentStr(key Key) string {
|
||||||
|
return strings.Repeat(enc.Indent, len(key)-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func encPanic(err error) {
|
||||||
|
panic(tomlEncodeError{err})
|
||||||
|
}
|
||||||
|
|
||||||
|
func eindirect(v reflect.Value) reflect.Value {
|
||||||
|
switch v.Kind() {
|
||||||
|
case reflect.Ptr, reflect.Interface:
|
||||||
|
return eindirect(v.Elem())
|
||||||
|
default:
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isNil(rv reflect.Value) bool {
|
||||||
|
switch rv.Kind() {
|
||||||
|
case reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
|
||||||
|
return rv.IsNil()
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func panicIfInvalidKey(key Key) {
|
||||||
|
for _, k := range key {
|
||||||
|
if len(k) == 0 {
|
||||||
|
encPanic(e("Key '%s' is not a valid table name. Key names "+
|
||||||
|
"cannot be empty.", key.maybeQuotedAll()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isValidKeyName reports whether s can be used as a key name; the only
// restriction is that it must be non-empty.
func isValidKeyName(s string) bool {
	return s != ""
}
|
@ -0,0 +1,19 @@
|
|||||||
|
// +build go1.2

package toml

// For Go 1.2 and later, TextMarshaler and TextUnmarshaler simply alias the
// standard library interfaces; the declarations exist only so that Go 1.1
// (which lacks the encoding package interfaces) can also be supported.

import (
	"encoding"
)

// TextMarshaler is a synonym for encoding.TextMarshaler. It is defined here
// so that Go 1.1 can be supported.
type TextMarshaler encoding.TextMarshaler

// TextUnmarshaler is a synonym for encoding.TextUnmarshaler. It is defined
// here so that Go 1.1 can be supported.
type TextUnmarshaler encoding.TextUnmarshaler
|
@ -0,0 +1,18 @@
|
|||||||
|
// +build !go1.2
|
||||||
|
|
||||||
|
package toml
|
||||||
|
|
||||||
|
// These interfaces were introduced in Go 1.2, so we add them manually when
|
||||||
|
// compiling for Go 1.1.
|
||||||
|
|
||||||
|
// TextMarshaler is a synonym for encoding.TextMarshaler. It is defined here
|
||||||
|
// so that Go 1.1 can be supported.
|
||||||
|
type TextMarshaler interface {
|
||||||
|
MarshalText() (text []byte, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TextUnmarshaler is a synonym for encoding.TextUnmarshaler. It is defined
|
||||||
|
// here so that Go 1.1 can be supported.
|
||||||
|
type TextUnmarshaler interface {
|
||||||
|
UnmarshalText(text []byte) error
|
||||||
|
}
|
@ -0,0 +1,953 @@
|
|||||||
|
package toml

import (
	"fmt"
	"strings"
	"unicode"
	"unicode/utf8"
)

// itemType identifies the kind of token emitted by the lexer.
type itemType int

const (
	itemError itemType = iota
	itemNIL   // used in the parser to indicate no type
	itemEOF
	itemText
	itemString
	itemRawString
	itemMultilineString
	itemRawMultilineString
	itemBool
	itemInteger
	itemFloat
	itemDatetime
	itemArray // the start of an array
	itemArrayEnd
	itemTableStart
	itemTableEnd
	itemArrayTableStart
	itemArrayTableEnd
	itemKeyStart
	itemCommentStart
	itemInlineTableStart
	itemInlineTableEnd
)

// Characters with structural meaning in TOML input.
const (
	eof              = 0
	comma            = ','
	tableStart       = '['
	tableEnd         = ']'
	arrayTableStart  = '['
	arrayTableEnd    = ']'
	tableSep         = '.'
	keySep           = '='
	arrayStart       = '['
	arrayEnd         = ']'
	commentStart     = '#'
	stringStart      = '"'
	stringEnd        = '"'
	rawStringStart   = '\''
	rawStringEnd     = '\''
	inlineTableStart = '{'
	inlineTableEnd   = '}'
)
|
||||||
|
|
||||||
|
type stateFn func(lx *lexer) stateFn
|
||||||
|
|
||||||
|
type lexer struct {
|
||||||
|
input string
|
||||||
|
start int
|
||||||
|
pos int
|
||||||
|
line int
|
||||||
|
state stateFn
|
||||||
|
items chan item
|
||||||
|
|
||||||
|
// Allow for backing up up to three runes.
|
||||||
|
// This is necessary because TOML contains 3-rune tokens (""" and ''').
|
||||||
|
prevWidths [3]int
|
||||||
|
nprev int // how many of prevWidths are in use
|
||||||
|
// If we emit an eof, we can still back up, but it is not OK to call
|
||||||
|
// next again.
|
||||||
|
atEOF bool
|
||||||
|
|
||||||
|
// A stack of state functions used to maintain context.
|
||||||
|
// The idea is to reuse parts of the state machine in various places.
|
||||||
|
// For example, values can appear at the top level or within arbitrarily
|
||||||
|
// nested arrays. The last state on the stack is used after a value has
|
||||||
|
// been lexed. Similarly for comments.
|
||||||
|
stack []stateFn
|
||||||
|
}
|
||||||
|
|
||||||
|
type item struct {
|
||||||
|
typ itemType
|
||||||
|
val string
|
||||||
|
line int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lx *lexer) nextItem() item {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case item := <-lx.items:
|
||||||
|
return item
|
||||||
|
default:
|
||||||
|
lx.state = lx.state(lx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func lex(input string) *lexer {
|
||||||
|
lx := &lexer{
|
||||||
|
input: input,
|
||||||
|
state: lexTop,
|
||||||
|
line: 1,
|
||||||
|
items: make(chan item, 10),
|
||||||
|
stack: make([]stateFn, 0, 10),
|
||||||
|
}
|
||||||
|
return lx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lx *lexer) push(state stateFn) {
|
||||||
|
lx.stack = append(lx.stack, state)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lx *lexer) pop() stateFn {
|
||||||
|
if len(lx.stack) == 0 {
|
||||||
|
return lx.errorf("BUG in lexer: no states to pop")
|
||||||
|
}
|
||||||
|
last := lx.stack[len(lx.stack)-1]
|
||||||
|
lx.stack = lx.stack[0 : len(lx.stack)-1]
|
||||||
|
return last
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lx *lexer) current() string {
|
||||||
|
return lx.input[lx.start:lx.pos]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lx *lexer) emit(typ itemType) {
|
||||||
|
lx.items <- item{typ, lx.current(), lx.line}
|
||||||
|
lx.start = lx.pos
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lx *lexer) emitTrim(typ itemType) {
|
||||||
|
lx.items <- item{typ, strings.TrimSpace(lx.current()), lx.line}
|
||||||
|
lx.start = lx.pos
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lx *lexer) next() (r rune) {
|
||||||
|
if lx.atEOF {
|
||||||
|
panic("next called after EOF")
|
||||||
|
}
|
||||||
|
if lx.pos >= len(lx.input) {
|
||||||
|
lx.atEOF = true
|
||||||
|
return eof
|
||||||
|
}
|
||||||
|
|
||||||
|
if lx.input[lx.pos] == '\n' {
|
||||||
|
lx.line++
|
||||||
|
}
|
||||||
|
lx.prevWidths[2] = lx.prevWidths[1]
|
||||||
|
lx.prevWidths[1] = lx.prevWidths[0]
|
||||||
|
if lx.nprev < 3 {
|
||||||
|
lx.nprev++
|
||||||
|
}
|
||||||
|
r, w := utf8.DecodeRuneInString(lx.input[lx.pos:])
|
||||||
|
lx.prevWidths[0] = w
|
||||||
|
lx.pos += w
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// ignore skips over the pending input before this point.
|
||||||
|
func (lx *lexer) ignore() {
|
||||||
|
lx.start = lx.pos
|
||||||
|
}
|
||||||
|
|
||||||
|
// backup steps back one rune. Can be called only twice between calls to next.
|
||||||
|
func (lx *lexer) backup() {
|
||||||
|
if lx.atEOF {
|
||||||
|
lx.atEOF = false
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if lx.nprev < 1 {
|
||||||
|
panic("backed up too far")
|
||||||
|
}
|
||||||
|
w := lx.prevWidths[0]
|
||||||
|
lx.prevWidths[0] = lx.prevWidths[1]
|
||||||
|
lx.prevWidths[1] = lx.prevWidths[2]
|
||||||
|
lx.nprev--
|
||||||
|
lx.pos -= w
|
||||||
|
if lx.pos < len(lx.input) && lx.input[lx.pos] == '\n' {
|
||||||
|
lx.line--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// accept consumes the next rune if it's equal to `valid`.
|
||||||
|
func (lx *lexer) accept(valid rune) bool {
|
||||||
|
if lx.next() == valid {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
lx.backup()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// peek returns but does not consume the next rune in the input.
|
||||||
|
func (lx *lexer) peek() rune {
|
||||||
|
r := lx.next()
|
||||||
|
lx.backup()
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// skip ignores all input that matches the given predicate.
|
||||||
|
func (lx *lexer) skip(pred func(rune) bool) {
|
||||||
|
for {
|
||||||
|
r := lx.next()
|
||||||
|
if pred(r) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
lx.backup()
|
||||||
|
lx.ignore()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// errorf stops all lexing by emitting an error and returning `nil`.
|
||||||
|
// Note that any value that is a character is escaped if it's a special
|
||||||
|
// character (newlines, tabs, etc.).
|
||||||
|
func (lx *lexer) errorf(format string, values ...interface{}) stateFn {
|
||||||
|
lx.items <- item{
|
||||||
|
itemError,
|
||||||
|
fmt.Sprintf(format, values...),
|
||||||
|
lx.line,
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// lexTop consumes elements at the top level of TOML data.
|
||||||
|
func lexTop(lx *lexer) stateFn {
|
||||||
|
r := lx.next()
|
||||||
|
if isWhitespace(r) || isNL(r) {
|
||||||
|
return lexSkip(lx, lexTop)
|
||||||
|
}
|
||||||
|
switch r {
|
||||||
|
case commentStart:
|
||||||
|
lx.push(lexTop)
|
||||||
|
return lexCommentStart
|
||||||
|
case tableStart:
|
||||||
|
return lexTableStart
|
||||||
|
case eof:
|
||||||
|
if lx.pos > lx.start {
|
||||||
|
return lx.errorf("unexpected EOF")
|
||||||
|
}
|
||||||
|
lx.emit(itemEOF)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// At this point, the only valid item can be a key, so we back up
|
||||||
|
// and let the key lexer do the rest.
|
||||||
|
lx.backup()
|
||||||
|
lx.push(lexTopEnd)
|
||||||
|
return lexKeyStart
|
||||||
|
}
|
||||||
|
|
||||||
|
// lexTopEnd is entered whenever a top-level item has been consumed. (A value
|
||||||
|
// or a table.) It must see only whitespace, and will turn back to lexTop
|
||||||
|
// upon a newline. If it sees EOF, it will quit the lexer successfully.
|
||||||
|
func lexTopEnd(lx *lexer) stateFn {
|
||||||
|
r := lx.next()
|
||||||
|
switch {
|
||||||
|
case r == commentStart:
|
||||||
|
// a comment will read to a newline for us.
|
||||||
|
lx.push(lexTop)
|
||||||
|
return lexCommentStart
|
||||||
|
case isWhitespace(r):
|
||||||
|
return lexTopEnd
|
||||||
|
case isNL(r):
|
||||||
|
lx.ignore()
|
||||||
|
return lexTop
|
||||||
|
case r == eof:
|
||||||
|
lx.emit(itemEOF)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return lx.errorf("expected a top-level item to end with a newline, "+
|
||||||
|
"comment, or EOF, but got %q instead", r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// lexTable lexes the beginning of a table. Namely, it makes sure that
|
||||||
|
// it starts with a character other than '.' and ']'.
|
||||||
|
// It assumes that '[' has already been consumed.
|
||||||
|
// It also handles the case that this is an item in an array of tables.
|
||||||
|
// e.g., '[[name]]'.
|
||||||
|
func lexTableStart(lx *lexer) stateFn {
|
||||||
|
if lx.peek() == arrayTableStart {
|
||||||
|
lx.next()
|
||||||
|
lx.emit(itemArrayTableStart)
|
||||||
|
lx.push(lexArrayTableEnd)
|
||||||
|
} else {
|
||||||
|
lx.emit(itemTableStart)
|
||||||
|
lx.push(lexTableEnd)
|
||||||
|
}
|
||||||
|
return lexTableNameStart
|
||||||
|
}
|
||||||
|
|
||||||
|
func lexTableEnd(lx *lexer) stateFn {
|
||||||
|
lx.emit(itemTableEnd)
|
||||||
|
return lexTopEnd
|
||||||
|
}
|
||||||
|
|
||||||
|
func lexArrayTableEnd(lx *lexer) stateFn {
|
||||||
|
if r := lx.next(); r != arrayTableEnd {
|
||||||
|
return lx.errorf("expected end of table array name delimiter %q, "+
|
||||||
|
"but got %q instead", arrayTableEnd, r)
|
||||||
|
}
|
||||||
|
lx.emit(itemArrayTableEnd)
|
||||||
|
return lexTopEnd
|
||||||
|
}
|
||||||
|
|
||||||
|
func lexTableNameStart(lx *lexer) stateFn {
|
||||||
|
lx.skip(isWhitespace)
|
||||||
|
switch r := lx.peek(); {
|
||||||
|
case r == tableEnd || r == eof:
|
||||||
|
return lx.errorf("unexpected end of table name " +
|
||||||
|
"(table names cannot be empty)")
|
||||||
|
case r == tableSep:
|
||||||
|
return lx.errorf("unexpected table separator " +
|
||||||
|
"(table names cannot be empty)")
|
||||||
|
case r == stringStart || r == rawStringStart:
|
||||||
|
lx.ignore()
|
||||||
|
lx.push(lexTableNameEnd)
|
||||||
|
return lexValue // reuse string lexing
|
||||||
|
default:
|
||||||
|
return lexBareTableName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// lexBareTableName lexes the name of a table. It assumes that at least one
|
||||||
|
// valid character for the table has already been read.
|
||||||
|
func lexBareTableName(lx *lexer) stateFn {
|
||||||
|
r := lx.next()
|
||||||
|
if isBareKeyChar(r) {
|
||||||
|
return lexBareTableName
|
||||||
|
}
|
||||||
|
lx.backup()
|
||||||
|
lx.emit(itemText)
|
||||||
|
return lexTableNameEnd
|
||||||
|
}
|
||||||
|
|
||||||
|
// lexTableNameEnd reads the end of a piece of a table name, optionally
|
||||||
|
// consuming whitespace.
|
||||||
|
func lexTableNameEnd(lx *lexer) stateFn {
|
||||||
|
lx.skip(isWhitespace)
|
||||||
|
switch r := lx.next(); {
|
||||||
|
case isWhitespace(r):
|
||||||
|
return lexTableNameEnd
|
||||||
|
case r == tableSep:
|
||||||
|
lx.ignore()
|
||||||
|
return lexTableNameStart
|
||||||
|
case r == tableEnd:
|
||||||
|
return lx.pop()
|
||||||
|
default:
|
||||||
|
return lx.errorf("expected '.' or ']' to end table name, "+
|
||||||
|
"but got %q instead", r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// lexKeyStart consumes a key name up until the first non-whitespace character.
|
||||||
|
// lexKeyStart will ignore whitespace.
|
||||||
|
func lexKeyStart(lx *lexer) stateFn {
|
||||||
|
r := lx.peek()
|
||||||
|
switch {
|
||||||
|
case r == keySep:
|
||||||
|
return lx.errorf("unexpected key separator %q", keySep)
|
||||||
|
case isWhitespace(r) || isNL(r):
|
||||||
|
lx.next()
|
||||||
|
return lexSkip(lx, lexKeyStart)
|
||||||
|
case r == stringStart || r == rawStringStart:
|
||||||
|
lx.ignore()
|
||||||
|
lx.emit(itemKeyStart)
|
||||||
|
lx.push(lexKeyEnd)
|
||||||
|
return lexValue // reuse string lexing
|
||||||
|
default:
|
||||||
|
lx.ignore()
|
||||||
|
lx.emit(itemKeyStart)
|
||||||
|
return lexBareKey
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// lexBareKey consumes the text of a bare key. Assumes that the first character
|
||||||
|
// (which is not whitespace) has not yet been consumed.
|
||||||
|
func lexBareKey(lx *lexer) stateFn {
|
||||||
|
switch r := lx.next(); {
|
||||||
|
case isBareKeyChar(r):
|
||||||
|
return lexBareKey
|
||||||
|
case isWhitespace(r):
|
||||||
|
lx.backup()
|
||||||
|
lx.emit(itemText)
|
||||||
|
return lexKeyEnd
|
||||||
|
case r == keySep:
|
||||||
|
lx.backup()
|
||||||
|
lx.emit(itemText)
|
||||||
|
return lexKeyEnd
|
||||||
|
default:
|
||||||
|
return lx.errorf("bare keys cannot contain %q", r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// lexKeyEnd consumes the end of a key and trims whitespace (up to the key
|
||||||
|
// separator).
|
||||||
|
func lexKeyEnd(lx *lexer) stateFn {
|
||||||
|
switch r := lx.next(); {
|
||||||
|
case r == keySep:
|
||||||
|
return lexSkip(lx, lexValue)
|
||||||
|
case isWhitespace(r):
|
||||||
|
return lexSkip(lx, lexKeyEnd)
|
||||||
|
default:
|
||||||
|
return lx.errorf("expected key separator %q, but got %q instead",
|
||||||
|
keySep, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// lexValue starts the consumption of a value anywhere a value is expected.
|
||||||
|
// lexValue will ignore whitespace.
|
||||||
|
// After a value is lexed, the last state on the next is popped and returned.
|
||||||
|
func lexValue(lx *lexer) stateFn {
|
||||||
|
// We allow whitespace to precede a value, but NOT newlines.
|
||||||
|
// In array syntax, the array states are responsible for ignoring newlines.
|
||||||
|
r := lx.next()
|
||||||
|
switch {
|
||||||
|
case isWhitespace(r):
|
||||||
|
return lexSkip(lx, lexValue)
|
||||||
|
case isDigit(r):
|
||||||
|
lx.backup() // avoid an extra state and use the same as above
|
||||||
|
return lexNumberOrDateStart
|
||||||
|
}
|
||||||
|
switch r {
|
||||||
|
case arrayStart:
|
||||||
|
lx.ignore()
|
||||||
|
lx.emit(itemArray)
|
||||||
|
return lexArrayValue
|
||||||
|
case inlineTableStart:
|
||||||
|
lx.ignore()
|
||||||
|
lx.emit(itemInlineTableStart)
|
||||||
|
return lexInlineTableValue
|
||||||
|
case stringStart:
|
||||||
|
if lx.accept(stringStart) {
|
||||||
|
if lx.accept(stringStart) {
|
||||||
|
lx.ignore() // Ignore """
|
||||||
|
return lexMultilineString
|
||||||
|
}
|
||||||
|
lx.backup()
|
||||||
|
}
|
||||||
|
lx.ignore() // ignore the '"'
|
||||||
|
return lexString
|
||||||
|
case rawStringStart:
|
||||||
|
if lx.accept(rawStringStart) {
|
||||||
|
if lx.accept(rawStringStart) {
|
||||||
|
lx.ignore() // Ignore """
|
||||||
|
return lexMultilineRawString
|
||||||
|
}
|
||||||
|
lx.backup()
|
||||||
|
}
|
||||||
|
lx.ignore() // ignore the "'"
|
||||||
|
return lexRawString
|
||||||
|
case '+', '-':
|
||||||
|
return lexNumberStart
|
||||||
|
case '.': // special error case, be kind to users
|
||||||
|
return lx.errorf("floats must start with a digit, not '.'")
|
||||||
|
}
|
||||||
|
if unicode.IsLetter(r) {
|
||||||
|
// Be permissive here; lexBool will give a nice error if the
|
||||||
|
// user wrote something like
|
||||||
|
// x = foo
|
||||||
|
// (i.e. not 'true' or 'false' but is something else word-like.)
|
||||||
|
lx.backup()
|
||||||
|
return lexBool
|
||||||
|
}
|
||||||
|
return lx.errorf("expected value but found %q instead", r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// lexArrayValue consumes one value in an array. It assumes that '[' or ','
|
||||||
|
// have already been consumed. All whitespace and newlines are ignored.
|
||||||
|
func lexArrayValue(lx *lexer) stateFn {
|
||||||
|
r := lx.next()
|
||||||
|
switch {
|
||||||
|
case isWhitespace(r) || isNL(r):
|
||||||
|
return lexSkip(lx, lexArrayValue)
|
||||||
|
case r == commentStart:
|
||||||
|
lx.push(lexArrayValue)
|
||||||
|
return lexCommentStart
|
||||||
|
case r == comma:
|
||||||
|
return lx.errorf("unexpected comma")
|
||||||
|
case r == arrayEnd:
|
||||||
|
// NOTE(caleb): The spec isn't clear about whether you can have
|
||||||
|
// a trailing comma or not, so we'll allow it.
|
||||||
|
return lexArrayEnd
|
||||||
|
}
|
||||||
|
|
||||||
|
lx.backup()
|
||||||
|
lx.push(lexArrayValueEnd)
|
||||||
|
return lexValue
|
||||||
|
}
|
||||||
|
|
||||||
|
// lexArrayValueEnd consumes everything between the end of an array value and
|
||||||
|
// the next value (or the end of the array): it ignores whitespace and newlines
|
||||||
|
// and expects either a ',' or a ']'.
|
||||||
|
func lexArrayValueEnd(lx *lexer) stateFn {
|
||||||
|
r := lx.next()
|
||||||
|
switch {
|
||||||
|
case isWhitespace(r) || isNL(r):
|
||||||
|
return lexSkip(lx, lexArrayValueEnd)
|
||||||
|
case r == commentStart:
|
||||||
|
lx.push(lexArrayValueEnd)
|
||||||
|
return lexCommentStart
|
||||||
|
case r == comma:
|
||||||
|
lx.ignore()
|
||||||
|
return lexArrayValue // move on to the next value
|
||||||
|
case r == arrayEnd:
|
||||||
|
return lexArrayEnd
|
||||||
|
}
|
||||||
|
return lx.errorf(
|
||||||
|
"expected a comma or array terminator %q, but got %q instead",
|
||||||
|
arrayEnd, r,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// lexArrayEnd finishes the lexing of an array.
|
||||||
|
// It assumes that a ']' has just been consumed.
|
||||||
|
func lexArrayEnd(lx *lexer) stateFn {
|
||||||
|
lx.ignore()
|
||||||
|
lx.emit(itemArrayEnd)
|
||||||
|
return lx.pop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// lexInlineTableValue consumes one key/value pair in an inline table.
|
||||||
|
// It assumes that '{' or ',' have already been consumed. Whitespace is ignored.
|
||||||
|
func lexInlineTableValue(lx *lexer) stateFn {
|
||||||
|
r := lx.next()
|
||||||
|
switch {
|
||||||
|
case isWhitespace(r):
|
||||||
|
return lexSkip(lx, lexInlineTableValue)
|
||||||
|
case isNL(r):
|
||||||
|
return lx.errorf("newlines not allowed within inline tables")
|
||||||
|
case r == commentStart:
|
||||||
|
lx.push(lexInlineTableValue)
|
||||||
|
return lexCommentStart
|
||||||
|
case r == comma:
|
||||||
|
return lx.errorf("unexpected comma")
|
||||||
|
case r == inlineTableEnd:
|
||||||
|
return lexInlineTableEnd
|
||||||
|
}
|
||||||
|
lx.backup()
|
||||||
|
lx.push(lexInlineTableValueEnd)
|
||||||
|
return lexKeyStart
|
||||||
|
}
|
||||||
|
|
||||||
|
// lexInlineTableValueEnd consumes everything between the end of an inline table
|
||||||
|
// key/value pair and the next pair (or the end of the table):
|
||||||
|
// it ignores whitespace and expects either a ',' or a '}'.
|
||||||
|
func lexInlineTableValueEnd(lx *lexer) stateFn {
|
||||||
|
r := lx.next()
|
||||||
|
switch {
|
||||||
|
case isWhitespace(r):
|
||||||
|
return lexSkip(lx, lexInlineTableValueEnd)
|
||||||
|
case isNL(r):
|
||||||
|
return lx.errorf("newlines not allowed within inline tables")
|
||||||
|
case r == commentStart:
|
||||||
|
lx.push(lexInlineTableValueEnd)
|
||||||
|
return lexCommentStart
|
||||||
|
case r == comma:
|
||||||
|
lx.ignore()
|
||||||
|
return lexInlineTableValue
|
||||||
|
case r == inlineTableEnd:
|
||||||
|
return lexInlineTableEnd
|
||||||
|
}
|
||||||
|
return lx.errorf("expected a comma or an inline table terminator %q, "+
|
||||||
|
"but got %q instead", inlineTableEnd, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// lexInlineTableEnd finishes the lexing of an inline table.
|
||||||
|
// It assumes that a '}' has just been consumed.
|
||||||
|
func lexInlineTableEnd(lx *lexer) stateFn {
|
||||||
|
lx.ignore()
|
||||||
|
lx.emit(itemInlineTableEnd)
|
||||||
|
return lx.pop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// lexString consumes the inner contents of a string. It assumes that the
|
||||||
|
// beginning '"' has already been consumed and ignored.
|
||||||
|
func lexString(lx *lexer) stateFn {
|
||||||
|
r := lx.next()
|
||||||
|
switch {
|
||||||
|
case r == eof:
|
||||||
|
return lx.errorf("unexpected EOF")
|
||||||
|
case isNL(r):
|
||||||
|
return lx.errorf("strings cannot contain newlines")
|
||||||
|
case r == '\\':
|
||||||
|
lx.push(lexString)
|
||||||
|
return lexStringEscape
|
||||||
|
case r == stringEnd:
|
||||||
|
lx.backup()
|
||||||
|
lx.emit(itemString)
|
||||||
|
lx.next()
|
||||||
|
lx.ignore()
|
||||||
|
return lx.pop()
|
||||||
|
}
|
||||||
|
return lexString
|
||||||
|
}
|
||||||
|
|
||||||
|
// lexMultilineString consumes the inner contents of a string. It assumes that
|
||||||
|
// the beginning '"""' has already been consumed and ignored.
|
||||||
|
func lexMultilineString(lx *lexer) stateFn {
|
||||||
|
switch lx.next() {
|
||||||
|
case eof:
|
||||||
|
return lx.errorf("unexpected EOF")
|
||||||
|
case '\\':
|
||||||
|
return lexMultilineStringEscape
|
||||||
|
case stringEnd:
|
||||||
|
if lx.accept(stringEnd) {
|
||||||
|
if lx.accept(stringEnd) {
|
||||||
|
lx.backup()
|
||||||
|
lx.backup()
|
||||||
|
lx.backup()
|
||||||
|
lx.emit(itemMultilineString)
|
||||||
|
lx.next()
|
||||||
|
lx.next()
|
||||||
|
lx.next()
|
||||||
|
lx.ignore()
|
||||||
|
return lx.pop()
|
||||||
|
}
|
||||||
|
lx.backup()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return lexMultilineString
|
||||||
|
}
|
||||||
|
|
||||||
|
// lexRawString consumes a raw string. Nothing can be escaped in such a string.
|
||||||
|
// It assumes that the beginning "'" has already been consumed and ignored.
|
||||||
|
func lexRawString(lx *lexer) stateFn {
|
||||||
|
r := lx.next()
|
||||||
|
switch {
|
||||||
|
case r == eof:
|
||||||
|
return lx.errorf("unexpected EOF")
|
||||||
|
case isNL(r):
|
||||||
|
return lx.errorf("strings cannot contain newlines")
|
||||||
|
case r == rawStringEnd:
|
||||||
|
lx.backup()
|
||||||
|
lx.emit(itemRawString)
|
||||||
|
lx.next()
|
||||||
|
lx.ignore()
|
||||||
|
return lx.pop()
|
||||||
|
}
|
||||||
|
return lexRawString
|
||||||
|
}
|
||||||
|
|
||||||
|
// lexMultilineRawString consumes a raw string. Nothing can be escaped in such
|
||||||
|
// a string. It assumes that the beginning "'''" has already been consumed and
|
||||||
|
// ignored.
|
||||||
|
func lexMultilineRawString(lx *lexer) stateFn {
|
||||||
|
switch lx.next() {
|
||||||
|
case eof:
|
||||||
|
return lx.errorf("unexpected EOF")
|
||||||
|
case rawStringEnd:
|
||||||
|
if lx.accept(rawStringEnd) {
|
||||||
|
if lx.accept(rawStringEnd) {
|
||||||
|
lx.backup()
|
||||||
|
lx.backup()
|
||||||
|
lx.backup()
|
||||||
|
lx.emit(itemRawMultilineString)
|
||||||
|
lx.next()
|
||||||
|
lx.next()
|
||||||
|
lx.next()
|
||||||
|
lx.ignore()
|
||||||
|
return lx.pop()
|
||||||
|
}
|
||||||
|
lx.backup()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return lexMultilineRawString
|
||||||
|
}
|
||||||
|
|
||||||
|
// lexMultilineStringEscape consumes an escaped character. It assumes that the
|
||||||
|
// preceding '\\' has already been consumed.
|
||||||
|
func lexMultilineStringEscape(lx *lexer) stateFn {
|
||||||
|
// Handle the special case first:
|
||||||
|
if isNL(lx.next()) {
|
||||||
|
return lexMultilineString
|
||||||
|
}
|
||||||
|
lx.backup()
|
||||||
|
lx.push(lexMultilineString)
|
||||||
|
return lexStringEscape(lx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func lexStringEscape(lx *lexer) stateFn {
|
||||||
|
r := lx.next()
|
||||||
|
switch r {
|
||||||
|
case 'b':
|
||||||
|
fallthrough
|
||||||
|
case 't':
|
||||||
|
fallthrough
|
||||||
|
case 'n':
|
||||||
|
fallthrough
|
||||||
|
case 'f':
|
||||||
|
fallthrough
|
||||||
|
case 'r':
|
||||||
|
fallthrough
|
||||||
|
case '"':
|
||||||
|
fallthrough
|
||||||
|
case '\\':
|
||||||
|
return lx.pop()
|
||||||
|
case 'u':
|
||||||
|
return lexShortUnicodeEscape
|
||||||
|
case 'U':
|
||||||
|
return lexLongUnicodeEscape
|
||||||
|
}
|
||||||
|
return lx.errorf("invalid escape character %q; only the following "+
|
||||||
|
"escape characters are allowed: "+
|
||||||
|
`\b, \t, \n, \f, \r, \", \\, \uXXXX, and \UXXXXXXXX`, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func lexShortUnicodeEscape(lx *lexer) stateFn {
|
||||||
|
var r rune
|
||||||
|
for i := 0; i < 4; i++ {
|
||||||
|
r = lx.next()
|
||||||
|
if !isHexadecimal(r) {
|
||||||
|
return lx.errorf(`expected four hexadecimal digits after '\u', `+
|
||||||
|
"but got %q instead", lx.current())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return lx.pop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func lexLongUnicodeEscape(lx *lexer) stateFn {
|
||||||
|
var r rune
|
||||||
|
for i := 0; i < 8; i++ {
|
||||||
|
r = lx.next()
|
||||||
|
if !isHexadecimal(r) {
|
||||||
|
return lx.errorf(`expected eight hexadecimal digits after '\U', `+
|
||||||
|
"but got %q instead", lx.current())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return lx.pop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// lexNumberOrDateStart consumes either an integer, a float, or datetime.
|
||||||
|
func lexNumberOrDateStart(lx *lexer) stateFn {
|
||||||
|
r := lx.next()
|
||||||
|
if isDigit(r) {
|
||||||
|
return lexNumberOrDate
|
||||||
|
}
|
||||||
|
switch r {
|
||||||
|
case '_':
|
||||||
|
return lexNumber
|
||||||
|
case 'e', 'E':
|
||||||
|
return lexFloat
|
||||||
|
case '.':
|
||||||
|
return lx.errorf("floats must start with a digit, not '.'")
|
||||||
|
}
|
||||||
|
return lx.errorf("expected a digit but got %q", r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// lexNumberOrDate consumes either an integer, float or datetime.
|
||||||
|
func lexNumberOrDate(lx *lexer) stateFn {
|
||||||
|
r := lx.next()
|
||||||
|
if isDigit(r) {
|
||||||
|
return lexNumberOrDate
|
||||||
|
}
|
||||||
|
switch r {
|
||||||
|
case '-':
|
||||||
|
return lexDatetime
|
||||||
|
case '_':
|
||||||
|
return lexNumber
|
||||||
|
case '.', 'e', 'E':
|
||||||
|
return lexFloat
|
||||||
|
}
|
||||||
|
|
||||||
|
lx.backup()
|
||||||
|
lx.emit(itemInteger)
|
||||||
|
return lx.pop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// lexDatetime consumes a Datetime, to a first approximation.
|
||||||
|
// The parser validates that it matches one of the accepted formats.
|
||||||
|
func lexDatetime(lx *lexer) stateFn {
|
||||||
|
r := lx.next()
|
||||||
|
if isDigit(r) {
|
||||||
|
return lexDatetime
|
||||||
|
}
|
||||||
|
switch r {
|
||||||
|
case '-', 'T', ':', '.', 'Z', '+':
|
||||||
|
return lexDatetime
|
||||||
|
}
|
||||||
|
|
||||||
|
lx.backup()
|
||||||
|
lx.emit(itemDatetime)
|
||||||
|
return lx.pop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// lexNumberStart consumes either an integer or a float. It assumes that a sign
|
||||||
|
// has already been read, but that *no* digits have been consumed.
|
||||||
|
// lexNumberStart will move to the appropriate integer or float states.
|
||||||
|
func lexNumberStart(lx *lexer) stateFn {
|
||||||
|
// We MUST see a digit. Even floats have to start with a digit.
|
||||||
|
r := lx.next()
|
||||||
|
if !isDigit(r) {
|
||||||
|
if r == '.' {
|
||||||
|
return lx.errorf("floats must start with a digit, not '.'")
|
||||||
|
}
|
||||||
|
return lx.errorf("expected a digit but got %q", r)
|
||||||
|
}
|
||||||
|
return lexNumber
|
||||||
|
}
|
||||||
|
|
||||||
|
// lexNumber consumes an integer or a float after seeing the first digit.
|
||||||
|
func lexNumber(lx *lexer) stateFn {
|
||||||
|
r := lx.next()
|
||||||
|
if isDigit(r) {
|
||||||
|
return lexNumber
|
||||||
|
}
|
||||||
|
switch r {
|
||||||
|
case '_':
|
||||||
|
return lexNumber
|
||||||
|
case '.', 'e', 'E':
|
||||||
|
return lexFloat
|
||||||
|
}
|
||||||
|
|
||||||
|
lx.backup()
|
||||||
|
lx.emit(itemInteger)
|
||||||
|
return lx.pop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// lexFloat consumes the elements of a float. It allows any sequence of
|
||||||
|
// float-like characters, so floats emitted by the lexer are only a first
|
||||||
|
// approximation and must be validated by the parser.
|
||||||
|
func lexFloat(lx *lexer) stateFn {
|
||||||
|
r := lx.next()
|
||||||
|
if isDigit(r) {
|
||||||
|
return lexFloat
|
||||||
|
}
|
||||||
|
switch r {
|
||||||
|
case '_', '.', '-', '+', 'e', 'E':
|
||||||
|
return lexFloat
|
||||||
|
}
|
||||||
|
|
||||||
|
lx.backup()
|
||||||
|
lx.emit(itemFloat)
|
||||||
|
return lx.pop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// lexBool consumes a bool string: 'true' or 'false.
|
||||||
|
func lexBool(lx *lexer) stateFn {
|
||||||
|
var rs []rune
|
||||||
|
for {
|
||||||
|
r := lx.next()
|
||||||
|
if !unicode.IsLetter(r) {
|
||||||
|
lx.backup()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
rs = append(rs, r)
|
||||||
|
}
|
||||||
|
s := string(rs)
|
||||||
|
switch s {
|
||||||
|
case "true", "false":
|
||||||
|
lx.emit(itemBool)
|
||||||
|
return lx.pop()
|
||||||
|
}
|
||||||
|
return lx.errorf("expected value but found %q instead", s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// lexCommentStart begins the lexing of a comment. It will emit
|
||||||
|
// itemCommentStart and consume no characters, passing control to lexComment.
|
||||||
|
func lexCommentStart(lx *lexer) stateFn {
|
||||||
|
lx.ignore()
|
||||||
|
lx.emit(itemCommentStart)
|
||||||
|
return lexComment
|
||||||
|
}
|
||||||
|
|
||||||
|
// lexComment lexes an entire comment. It assumes that '#' has been consumed.
|
||||||
|
// It will consume *up to* the first newline character, and pass control
|
||||||
|
// back to the last state on the stack.
|
||||||
|
func lexComment(lx *lexer) stateFn {
|
||||||
|
r := lx.peek()
|
||||||
|
if isNL(r) || r == eof {
|
||||||
|
lx.emit(itemText)
|
||||||
|
return lx.pop()
|
||||||
|
}
|
||||||
|
lx.next()
|
||||||
|
return lexComment
|
||||||
|
}
|
||||||
|
|
||||||
|
// lexSkip ignores all slurped input and moves on to the next state.
|
||||||
|
func lexSkip(lx *lexer, nextState stateFn) stateFn {
|
||||||
|
return func(lx *lexer) stateFn {
|
||||||
|
lx.ignore()
|
||||||
|
return nextState
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isWhitespace returns true if r is a whitespace character according to the
// spec (tab or space only; newlines are handled separately by isNL).
func isWhitespace(r rune) bool {
	switch r {
	case '\t', ' ':
		return true
	}
	return false
}
|
||||||
|
|
||||||
|
// isNL reports whether r is a newline character (LF or CR).
func isNL(r rune) bool {
	switch r {
	case '\n', '\r':
		return true
	}
	return false
}
|
||||||
|
|
||||||
|
// isDigit reports whether r is an ASCII decimal digit.
func isDigit(r rune) bool {
	return '0' <= r && r <= '9'
}
|
||||||
|
|
||||||
|
// isHexadecimal reports whether r is an ASCII hexadecimal digit
// (either case).
func isHexadecimal(r rune) bool {
	switch {
	case '0' <= r && r <= '9':
		return true
	case 'a' <= r && r <= 'f':
		return true
	case 'A' <= r && r <= 'F':
		return true
	}
	return false
}
|
||||||
|
|
||||||
|
// isBareKeyChar reports whether r may appear in an unquoted TOML key:
// ASCII letters, digits, underscore, or hyphen.
func isBareKeyChar(r rune) bool {
	switch {
	case 'A' <= r && r <= 'Z', 'a' <= r && r <= 'z':
		return true
	case '0' <= r && r <= '9':
		return true
	case r == '_', r == '-':
		return true
	}
	return false
}
|
||||||
|
|
||||||
|
func (itype itemType) String() string {
|
||||||
|
switch itype {
|
||||||
|
case itemError:
|
||||||
|
return "Error"
|
||||||
|
case itemNIL:
|
||||||
|
return "NIL"
|
||||||
|
case itemEOF:
|
||||||
|
return "EOF"
|
||||||
|
case itemText:
|
||||||
|
return "Text"
|
||||||
|
case itemString, itemRawString, itemMultilineString, itemRawMultilineString:
|
||||||
|
return "String"
|
||||||
|
case itemBool:
|
||||||
|
return "Bool"
|
||||||
|
case itemInteger:
|
||||||
|
return "Integer"
|
||||||
|
case itemFloat:
|
||||||
|
return "Float"
|
||||||
|
case itemDatetime:
|
||||||
|
return "DateTime"
|
||||||
|
case itemTableStart:
|
||||||
|
return "TableStart"
|
||||||
|
case itemTableEnd:
|
||||||
|
return "TableEnd"
|
||||||
|
case itemKeyStart:
|
||||||
|
return "KeyStart"
|
||||||
|
case itemArray:
|
||||||
|
return "Array"
|
||||||
|
case itemArrayEnd:
|
||||||
|
return "ArrayEnd"
|
||||||
|
case itemCommentStart:
|
||||||
|
return "CommentStart"
|
||||||
|
}
|
||||||
|
panic(fmt.Sprintf("BUG: Unknown type '%d'.", int(itype)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (item item) String() string {
|
||||||
|
return fmt.Sprintf("(%s, %s)", item.typ.String(), item.val)
|
||||||
|
}
|
@ -0,0 +1,592 @@
|
|||||||
|
package toml
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
"unicode"
|
||||||
|
"unicode/utf8"
|
||||||
|
)
|
||||||
|
|
||||||
|
type parser struct {
|
||||||
|
mapping map[string]interface{}
|
||||||
|
types map[string]tomlType
|
||||||
|
lx *lexer
|
||||||
|
|
||||||
|
// A list of keys in the order that they appear in the TOML data.
|
||||||
|
ordered []Key
|
||||||
|
|
||||||
|
// the full key for the current hash in scope
|
||||||
|
context Key
|
||||||
|
|
||||||
|
// the base key name for everything except hashes
|
||||||
|
currentKey string
|
||||||
|
|
||||||
|
// rough approximation of line number
|
||||||
|
approxLine int
|
||||||
|
|
||||||
|
// A map of 'key.group.names' to whether they were created implicitly.
|
||||||
|
implicits map[string]bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseError is the panic payload used internally to unwind the parser on
// malformed input; parse recovers it and returns it as a plain error.
type parseError string

// Error implements the error interface.
func (pe parseError) Error() string {
	return string(pe)
}
|
||||||
|
|
||||||
|
func parse(data string) (p *parser, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
var ok bool
|
||||||
|
if err, ok = r.(parseError); ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
panic(r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
p = &parser{
|
||||||
|
mapping: make(map[string]interface{}),
|
||||||
|
types: make(map[string]tomlType),
|
||||||
|
lx: lex(data),
|
||||||
|
ordered: make([]Key, 0),
|
||||||
|
implicits: make(map[string]bool),
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
item := p.next()
|
||||||
|
if item.typ == itemEOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
p.topLevel(item)
|
||||||
|
}
|
||||||
|
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) panicf(format string, v ...interface{}) {
|
||||||
|
msg := fmt.Sprintf("Near line %d (last key parsed '%s'): %s",
|
||||||
|
p.approxLine, p.current(), fmt.Sprintf(format, v...))
|
||||||
|
panic(parseError(msg))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) next() item {
|
||||||
|
it := p.lx.nextItem()
|
||||||
|
if it.typ == itemError {
|
||||||
|
p.panicf("%s", it.val)
|
||||||
|
}
|
||||||
|
return it
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) bug(format string, v ...interface{}) {
|
||||||
|
panic(fmt.Sprintf("BUG: "+format+"\n\n", v...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) expect(typ itemType) item {
|
||||||
|
it := p.next()
|
||||||
|
p.assertEqual(typ, it.typ)
|
||||||
|
return it
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) assertEqual(expected, got itemType) {
|
||||||
|
if expected != got {
|
||||||
|
p.bug("Expected '%s' but got '%s'.", expected, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) topLevel(item item) {
|
||||||
|
switch item.typ {
|
||||||
|
case itemCommentStart:
|
||||||
|
p.approxLine = item.line
|
||||||
|
p.expect(itemText)
|
||||||
|
case itemTableStart:
|
||||||
|
kg := p.next()
|
||||||
|
p.approxLine = kg.line
|
||||||
|
|
||||||
|
var key Key
|
||||||
|
for ; kg.typ != itemTableEnd && kg.typ != itemEOF; kg = p.next() {
|
||||||
|
key = append(key, p.keyString(kg))
|
||||||
|
}
|
||||||
|
p.assertEqual(itemTableEnd, kg.typ)
|
||||||
|
|
||||||
|
p.establishContext(key, false)
|
||||||
|
p.setType("", tomlHash)
|
||||||
|
p.ordered = append(p.ordered, key)
|
||||||
|
case itemArrayTableStart:
|
||||||
|
kg := p.next()
|
||||||
|
p.approxLine = kg.line
|
||||||
|
|
||||||
|
var key Key
|
||||||
|
for ; kg.typ != itemArrayTableEnd && kg.typ != itemEOF; kg = p.next() {
|
||||||
|
key = append(key, p.keyString(kg))
|
||||||
|
}
|
||||||
|
p.assertEqual(itemArrayTableEnd, kg.typ)
|
||||||
|
|
||||||
|
p.establishContext(key, true)
|
||||||
|
p.setType("", tomlArrayHash)
|
||||||
|
p.ordered = append(p.ordered, key)
|
||||||
|
case itemKeyStart:
|
||||||
|
kname := p.next()
|
||||||
|
p.approxLine = kname.line
|
||||||
|
p.currentKey = p.keyString(kname)
|
||||||
|
|
||||||
|
val, typ := p.value(p.next())
|
||||||
|
p.setValue(p.currentKey, val)
|
||||||
|
p.setType(p.currentKey, typ)
|
||||||
|
p.ordered = append(p.ordered, p.context.add(p.currentKey))
|
||||||
|
p.currentKey = ""
|
||||||
|
default:
|
||||||
|
p.bug("Unexpected type at top level: %s", item.typ)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Gets a string for a key (or part of a key in a table name).
|
||||||
|
func (p *parser) keyString(it item) string {
|
||||||
|
switch it.typ {
|
||||||
|
case itemText:
|
||||||
|
return it.val
|
||||||
|
case itemString, itemMultilineString,
|
||||||
|
itemRawString, itemRawMultilineString:
|
||||||
|
s, _ := p.value(it)
|
||||||
|
return s.(string)
|
||||||
|
default:
|
||||||
|
p.bug("Unexpected key type: %s", it.typ)
|
||||||
|
panic("unreachable")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// value translates an expected value from the lexer into a Go value wrapped
|
||||||
|
// as an empty interface.
|
||||||
|
func (p *parser) value(it item) (interface{}, tomlType) {
|
||||||
|
switch it.typ {
|
||||||
|
case itemString:
|
||||||
|
return p.replaceEscapes(it.val), p.typeOfPrimitive(it)
|
||||||
|
case itemMultilineString:
|
||||||
|
trimmed := stripFirstNewline(stripEscapedWhitespace(it.val))
|
||||||
|
return p.replaceEscapes(trimmed), p.typeOfPrimitive(it)
|
||||||
|
case itemRawString:
|
||||||
|
return it.val, p.typeOfPrimitive(it)
|
||||||
|
case itemRawMultilineString:
|
||||||
|
return stripFirstNewline(it.val), p.typeOfPrimitive(it)
|
||||||
|
case itemBool:
|
||||||
|
switch it.val {
|
||||||
|
case "true":
|
||||||
|
return true, p.typeOfPrimitive(it)
|
||||||
|
case "false":
|
||||||
|
return false, p.typeOfPrimitive(it)
|
||||||
|
}
|
||||||
|
p.bug("Expected boolean value, but got '%s'.", it.val)
|
||||||
|
case itemInteger:
|
||||||
|
if !numUnderscoresOK(it.val) {
|
||||||
|
p.panicf("Invalid integer %q: underscores must be surrounded by digits",
|
||||||
|
it.val)
|
||||||
|
}
|
||||||
|
val := strings.Replace(it.val, "_", "", -1)
|
||||||
|
num, err := strconv.ParseInt(val, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
// Distinguish integer values. Normally, it'd be a bug if the lexer
|
||||||
|
// provides an invalid integer, but it's possible that the number is
|
||||||
|
// out of range of valid values (which the lexer cannot determine).
|
||||||
|
// So mark the former as a bug but the latter as a legitimate user
|
||||||
|
// error.
|
||||||
|
if e, ok := err.(*strconv.NumError); ok &&
|
||||||
|
e.Err == strconv.ErrRange {
|
||||||
|
|
||||||
|
p.panicf("Integer '%s' is out of the range of 64-bit "+
|
||||||
|
"signed integers.", it.val)
|
||||||
|
} else {
|
||||||
|
p.bug("Expected integer value, but got '%s'.", it.val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return num, p.typeOfPrimitive(it)
|
||||||
|
case itemFloat:
|
||||||
|
parts := strings.FieldsFunc(it.val, func(r rune) bool {
|
||||||
|
switch r {
|
||||||
|
case '.', 'e', 'E':
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
})
|
||||||
|
for _, part := range parts {
|
||||||
|
if !numUnderscoresOK(part) {
|
||||||
|
p.panicf("Invalid float %q: underscores must be "+
|
||||||
|
"surrounded by digits", it.val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !numPeriodsOK(it.val) {
|
||||||
|
// As a special case, numbers like '123.' or '1.e2',
|
||||||
|
// which are valid as far as Go/strconv are concerned,
|
||||||
|
// must be rejected because TOML says that a fractional
|
||||||
|
// part consists of '.' followed by 1+ digits.
|
||||||
|
p.panicf("Invalid float %q: '.' must be followed "+
|
||||||
|
"by one or more digits", it.val)
|
||||||
|
}
|
||||||
|
val := strings.Replace(it.val, "_", "", -1)
|
||||||
|
num, err := strconv.ParseFloat(val, 64)
|
||||||
|
if err != nil {
|
||||||
|
if e, ok := err.(*strconv.NumError); ok &&
|
||||||
|
e.Err == strconv.ErrRange {
|
||||||
|
|
||||||
|
p.panicf("Float '%s' is out of the range of 64-bit "+
|
||||||
|
"IEEE-754 floating-point numbers.", it.val)
|
||||||
|
} else {
|
||||||
|
p.panicf("Invalid float value: %q", it.val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return num, p.typeOfPrimitive(it)
|
||||||
|
case itemDatetime:
|
||||||
|
var t time.Time
|
||||||
|
var ok bool
|
||||||
|
var err error
|
||||||
|
for _, format := range []string{
|
||||||
|
"2006-01-02T15:04:05Z07:00",
|
||||||
|
"2006-01-02T15:04:05",
|
||||||
|
"2006-01-02",
|
||||||
|
} {
|
||||||
|
t, err = time.ParseInLocation(format, it.val, time.Local)
|
||||||
|
if err == nil {
|
||||||
|
ok = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
p.panicf("Invalid TOML Datetime: %q.", it.val)
|
||||||
|
}
|
||||||
|
return t, p.typeOfPrimitive(it)
|
||||||
|
case itemArray:
|
||||||
|
array := make([]interface{}, 0)
|
||||||
|
types := make([]tomlType, 0)
|
||||||
|
|
||||||
|
for it = p.next(); it.typ != itemArrayEnd; it = p.next() {
|
||||||
|
if it.typ == itemCommentStart {
|
||||||
|
p.expect(itemText)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
val, typ := p.value(it)
|
||||||
|
array = append(array, val)
|
||||||
|
types = append(types, typ)
|
||||||
|
}
|
||||||
|
return array, p.typeOfArray(types)
|
||||||
|
case itemInlineTableStart:
|
||||||
|
var (
|
||||||
|
hash = make(map[string]interface{})
|
||||||
|
outerContext = p.context
|
||||||
|
outerKey = p.currentKey
|
||||||
|
)
|
||||||
|
|
||||||
|
p.context = append(p.context, p.currentKey)
|
||||||
|
p.currentKey = ""
|
||||||
|
for it := p.next(); it.typ != itemInlineTableEnd; it = p.next() {
|
||||||
|
if it.typ != itemKeyStart {
|
||||||
|
p.bug("Expected key start but instead found %q, around line %d",
|
||||||
|
it.val, p.approxLine)
|
||||||
|
}
|
||||||
|
if it.typ == itemCommentStart {
|
||||||
|
p.expect(itemText)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// retrieve key
|
||||||
|
k := p.next()
|
||||||
|
p.approxLine = k.line
|
||||||
|
kname := p.keyString(k)
|
||||||
|
|
||||||
|
// retrieve value
|
||||||
|
p.currentKey = kname
|
||||||
|
val, typ := p.value(p.next())
|
||||||
|
// make sure we keep metadata up to date
|
||||||
|
p.setType(kname, typ)
|
||||||
|
p.ordered = append(p.ordered, p.context.add(p.currentKey))
|
||||||
|
hash[kname] = val
|
||||||
|
}
|
||||||
|
p.context = outerContext
|
||||||
|
p.currentKey = outerKey
|
||||||
|
return hash, tomlHash
|
||||||
|
}
|
||||||
|
p.bug("Unexpected value type: %s", it.typ)
|
||||||
|
panic("unreachable")
|
||||||
|
}
|
||||||
|
|
||||||
|
// numUnderscoresOK checks whether each underscore in s is surrounded by
// characters that are not underscores (so the string may neither start nor
// end with one, nor contain a doubled underscore).
func numUnderscoresOK(s string) bool {
	prevOK := false
	for _, r := range s {
		if r != '_' {
			prevOK = true
			continue
		}
		// An underscore needs a non-underscore immediately before it...
		if !prevOK {
			return false
		}
		// ...and must not be followed by another underscore.
		prevOK = false
	}
	// A trailing underscore leaves prevOK false.
	return prevOK
}
|
||||||
|
|
||||||
|
// numPeriodsOK checks whether every period in s is followed by a digit.
func numPeriodsOK(s string) bool {
	afterPeriod := false
	for _, r := range s {
		if afterPeriod && !('0' <= r && r <= '9') {
			return false
		}
		afterPeriod = r == '.'
	}
	// A trailing period has no digit after it.
	return !afterPeriod
}
|
||||||
|
|
||||||
|
// establishContext sets the current context of the parser,
|
||||||
|
// where the context is either a hash or an array of hashes. Which one is
|
||||||
|
// set depends on the value of the `array` parameter.
|
||||||
|
//
|
||||||
|
// Establishing the context also makes sure that the key isn't a duplicate, and
|
||||||
|
// will create implicit hashes automatically.
|
||||||
|
func (p *parser) establishContext(key Key, array bool) {
|
||||||
|
var ok bool
|
||||||
|
|
||||||
|
// Always start at the top level and drill down for our context.
|
||||||
|
hashContext := p.mapping
|
||||||
|
keyContext := make(Key, 0)
|
||||||
|
|
||||||
|
// We only need implicit hashes for key[0:-1]
|
||||||
|
for _, k := range key[0 : len(key)-1] {
|
||||||
|
_, ok = hashContext[k]
|
||||||
|
keyContext = append(keyContext, k)
|
||||||
|
|
||||||
|
// No key? Make an implicit hash and move on.
|
||||||
|
if !ok {
|
||||||
|
p.addImplicit(keyContext)
|
||||||
|
hashContext[k] = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the hash context is actually an array of tables, then set
|
||||||
|
// the hash context to the last element in that array.
|
||||||
|
//
|
||||||
|
// Otherwise, it better be a table, since this MUST be a key group (by
|
||||||
|
// virtue of it not being the last element in a key).
|
||||||
|
switch t := hashContext[k].(type) {
|
||||||
|
case []map[string]interface{}:
|
||||||
|
hashContext = t[len(t)-1]
|
||||||
|
case map[string]interface{}:
|
||||||
|
hashContext = t
|
||||||
|
default:
|
||||||
|
p.panicf("Key '%s' was already created as a hash.", keyContext)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
p.context = keyContext
|
||||||
|
if array {
|
||||||
|
// If this is the first element for this array, then allocate a new
|
||||||
|
// list of tables for it.
|
||||||
|
k := key[len(key)-1]
|
||||||
|
if _, ok := hashContext[k]; !ok {
|
||||||
|
hashContext[k] = make([]map[string]interface{}, 0, 5)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add a new table. But make sure the key hasn't already been used
|
||||||
|
// for something else.
|
||||||
|
if hash, ok := hashContext[k].([]map[string]interface{}); ok {
|
||||||
|
hashContext[k] = append(hash, make(map[string]interface{}))
|
||||||
|
} else {
|
||||||
|
p.panicf("Key '%s' was already created and cannot be used as "+
|
||||||
|
"an array.", keyContext)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
p.setValue(key[len(key)-1], make(map[string]interface{}))
|
||||||
|
}
|
||||||
|
p.context = append(p.context, key[len(key)-1])
|
||||||
|
}
|
||||||
|
|
||||||
|
// setValue sets the given key to the given value in the current context.
|
||||||
|
// It will make sure that the key hasn't already been defined, account for
|
||||||
|
// implicit key groups.
|
||||||
|
func (p *parser) setValue(key string, value interface{}) {
|
||||||
|
var tmpHash interface{}
|
||||||
|
var ok bool
|
||||||
|
|
||||||
|
hash := p.mapping
|
||||||
|
keyContext := make(Key, 0)
|
||||||
|
for _, k := range p.context {
|
||||||
|
keyContext = append(keyContext, k)
|
||||||
|
if tmpHash, ok = hash[k]; !ok {
|
||||||
|
p.bug("Context for key '%s' has not been established.", keyContext)
|
||||||
|
}
|
||||||
|
switch t := tmpHash.(type) {
|
||||||
|
case []map[string]interface{}:
|
||||||
|
// The context is a table of hashes. Pick the most recent table
|
||||||
|
// defined as the current hash.
|
||||||
|
hash = t[len(t)-1]
|
||||||
|
case map[string]interface{}:
|
||||||
|
hash = t
|
||||||
|
default:
|
||||||
|
p.bug("Expected hash to have type 'map[string]interface{}', but "+
|
||||||
|
"it has '%T' instead.", tmpHash)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
keyContext = append(keyContext, key)
|
||||||
|
|
||||||
|
if _, ok := hash[key]; ok {
|
||||||
|
// Typically, if the given key has already been set, then we have
|
||||||
|
// to raise an error since duplicate keys are disallowed. However,
|
||||||
|
// it's possible that a key was previously defined implicitly. In this
|
||||||
|
// case, it is allowed to be redefined concretely. (See the
|
||||||
|
// `tests/valid/implicit-and-explicit-after.toml` test in `toml-test`.)
|
||||||
|
//
|
||||||
|
// But we have to make sure to stop marking it as an implicit. (So that
|
||||||
|
// another redefinition provokes an error.)
|
||||||
|
//
|
||||||
|
// Note that since it has already been defined (as a hash), we don't
|
||||||
|
// want to overwrite it. So our business is done.
|
||||||
|
if p.isImplicit(keyContext) {
|
||||||
|
p.removeImplicit(keyContext)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, we have a concrete key trying to override a previous
|
||||||
|
// key, which is *always* wrong.
|
||||||
|
p.panicf("Key '%s' has already been defined.", keyContext)
|
||||||
|
}
|
||||||
|
hash[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
// setType sets the type of a particular value at a given key.
|
||||||
|
// It should be called immediately AFTER setValue.
|
||||||
|
//
|
||||||
|
// Note that if `key` is empty, then the type given will be applied to the
|
||||||
|
// current context (which is either a table or an array of tables).
|
||||||
|
func (p *parser) setType(key string, typ tomlType) {
|
||||||
|
keyContext := make(Key, 0, len(p.context)+1)
|
||||||
|
for _, k := range p.context {
|
||||||
|
keyContext = append(keyContext, k)
|
||||||
|
}
|
||||||
|
if len(key) > 0 { // allow type setting for hashes
|
||||||
|
keyContext = append(keyContext, key)
|
||||||
|
}
|
||||||
|
p.types[keyContext.String()] = typ
|
||||||
|
}
|
||||||
|
|
||||||
|
// addImplicit sets the given Key as having been created implicitly.
|
||||||
|
func (p *parser) addImplicit(key Key) {
|
||||||
|
p.implicits[key.String()] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeImplicit stops tagging the given key as having been implicitly
|
||||||
|
// created.
|
||||||
|
func (p *parser) removeImplicit(key Key) {
|
||||||
|
p.implicits[key.String()] = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// isImplicit returns true if the key group pointed to by the key was created
|
||||||
|
// implicitly.
|
||||||
|
func (p *parser) isImplicit(key Key) bool {
|
||||||
|
return p.implicits[key.String()]
|
||||||
|
}
|
||||||
|
|
||||||
|
// current returns the full key name of the current context.
|
||||||
|
func (p *parser) current() string {
|
||||||
|
if len(p.currentKey) == 0 {
|
||||||
|
return p.context.String()
|
||||||
|
}
|
||||||
|
if len(p.context) == 0 {
|
||||||
|
return p.currentKey
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s.%s", p.context, p.currentKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// stripFirstNewline drops a single leading '\n' from s, if present. Used
// for multiline strings, whose first newline is ignored per the TOML spec.
func stripFirstNewline(s string) string {
	if strings.HasPrefix(s, "\n") {
		return s[1:]
	}
	return s
}
|
||||||
|
|
||||||
|
// stripEscapedWhitespace implements TOML's line-ending backslash: it removes
// each "\<newline>" together with all whitespace that follows it.
func stripEscapedWhitespace(s string) string {
	parts := strings.Split(s, "\\\n")
	for i := 1; i < len(parts); i++ {
		parts[i] = strings.TrimLeftFunc(parts[i], unicode.IsSpace)
	}
	return strings.Join(parts, "")
}
|
||||||
|
|
||||||
|
func (p *parser) replaceEscapes(str string) string {
|
||||||
|
var replaced []rune
|
||||||
|
s := []byte(str)
|
||||||
|
r := 0
|
||||||
|
for r < len(s) {
|
||||||
|
if s[r] != '\\' {
|
||||||
|
c, size := utf8.DecodeRune(s[r:])
|
||||||
|
r += size
|
||||||
|
replaced = append(replaced, c)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
r += 1
|
||||||
|
if r >= len(s) {
|
||||||
|
p.bug("Escape sequence at end of string.")
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
switch s[r] {
|
||||||
|
default:
|
||||||
|
p.bug("Expected valid escape code after \\, but got %q.", s[r])
|
||||||
|
return ""
|
||||||
|
case 'b':
|
||||||
|
replaced = append(replaced, rune(0x0008))
|
||||||
|
r += 1
|
||||||
|
case 't':
|
||||||
|
replaced = append(replaced, rune(0x0009))
|
||||||
|
r += 1
|
||||||
|
case 'n':
|
||||||
|
replaced = append(replaced, rune(0x000A))
|
||||||
|
r += 1
|
||||||
|
case 'f':
|
||||||
|
replaced = append(replaced, rune(0x000C))
|
||||||
|
r += 1
|
||||||
|
case 'r':
|
||||||
|
replaced = append(replaced, rune(0x000D))
|
||||||
|
r += 1
|
||||||
|
case '"':
|
||||||
|
replaced = append(replaced, rune(0x0022))
|
||||||
|
r += 1
|
||||||
|
case '\\':
|
||||||
|
replaced = append(replaced, rune(0x005C))
|
||||||
|
r += 1
|
||||||
|
case 'u':
|
||||||
|
// At this point, we know we have a Unicode escape of the form
|
||||||
|
// `uXXXX` at [r, r+5). (Because the lexer guarantees this
|
||||||
|
// for us.)
|
||||||
|
escaped := p.asciiEscapeToUnicode(s[r+1 : r+5])
|
||||||
|
replaced = append(replaced, escaped)
|
||||||
|
r += 5
|
||||||
|
case 'U':
|
||||||
|
// At this point, we know we have a Unicode escape of the form
|
||||||
|
// `uXXXX` at [r, r+9). (Because the lexer guarantees this
|
||||||
|
// for us.)
|
||||||
|
escaped := p.asciiEscapeToUnicode(s[r+1 : r+9])
|
||||||
|
replaced = append(replaced, escaped)
|
||||||
|
r += 9
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return string(replaced)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) asciiEscapeToUnicode(bs []byte) rune {
|
||||||
|
s := string(bs)
|
||||||
|
hex, err := strconv.ParseUint(strings.ToLower(s), 16, 32)
|
||||||
|
if err != nil {
|
||||||
|
p.bug("Could not parse '%s' as a hexadecimal number, but the "+
|
||||||
|
"lexer claims it's OK: %s", s, err)
|
||||||
|
}
|
||||||
|
if !utf8.ValidRune(rune(hex)) {
|
||||||
|
p.panicf("Escaped character '\\u%s' is not valid UTF-8.", s)
|
||||||
|
}
|
||||||
|
return rune(hex)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isStringType(ty itemType) bool {
|
||||||
|
return ty == itemString || ty == itemMultilineString ||
|
||||||
|
ty == itemRawString || ty == itemRawMultilineString
|
||||||
|
}
|
@ -0,0 +1 @@
|
|||||||
|
au BufWritePost *.go silent!make tags > /dev/null 2>&1
|
@ -0,0 +1,91 @@
|
|||||||
|
package toml
|
||||||
|
|
||||||
|
// tomlType represents any Go type that corresponds to a TOML type.
// While the first draft of the TOML spec has a simplistic type system that
// probably doesn't need this level of sophistication, we seem to be
// militating toward adding real composite types.
type tomlType interface {
	typeString() string
}
|
||||||
|
|
||||||
|
// typeEqual accepts any two types and returns true if they are equal.
|
||||||
|
func typeEqual(t1, t2 tomlType) bool {
|
||||||
|
if t1 == nil || t2 == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return t1.typeString() == t2.typeString()
|
||||||
|
}
|
||||||
|
|
||||||
|
func typeIsHash(t tomlType) bool {
|
||||||
|
return typeEqual(t, tomlHash) || typeEqual(t, tomlArrayHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
// tomlBaseType names one of the built-in TOML types.
type tomlBaseType string

// typeString implements tomlType.
func (btype tomlBaseType) typeString() string {
	return string(btype)
}
|
||||||
|
|
||||||
|
func (btype tomlBaseType) String() string {
|
||||||
|
return btype.typeString()
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
tomlInteger tomlBaseType = "Integer"
|
||||||
|
tomlFloat tomlBaseType = "Float"
|
||||||
|
tomlDatetime tomlBaseType = "Datetime"
|
||||||
|
tomlString tomlBaseType = "String"
|
||||||
|
tomlBool tomlBaseType = "Bool"
|
||||||
|
tomlArray tomlBaseType = "Array"
|
||||||
|
tomlHash tomlBaseType = "Hash"
|
||||||
|
tomlArrayHash tomlBaseType = "ArrayHash"
|
||||||
|
)
|
||||||
|
|
||||||
|
// typeOfPrimitive returns a tomlType of any primitive value in TOML.
|
||||||
|
// Primitive values are: Integer, Float, Datetime, String and Bool.
|
||||||
|
//
|
||||||
|
// Passing a lexer item other than the following will cause a BUG message
|
||||||
|
// to occur: itemString, itemBool, itemInteger, itemFloat, itemDatetime.
|
||||||
|
func (p *parser) typeOfPrimitive(lexItem item) tomlType {
|
||||||
|
switch lexItem.typ {
|
||||||
|
case itemInteger:
|
||||||
|
return tomlInteger
|
||||||
|
case itemFloat:
|
||||||
|
return tomlFloat
|
||||||
|
case itemDatetime:
|
||||||
|
return tomlDatetime
|
||||||
|
case itemString:
|
||||||
|
return tomlString
|
||||||
|
case itemMultilineString:
|
||||||
|
return tomlString
|
||||||
|
case itemRawString:
|
||||||
|
return tomlString
|
||||||
|
case itemRawMultilineString:
|
||||||
|
return tomlString
|
||||||
|
case itemBool:
|
||||||
|
return tomlBool
|
||||||
|
}
|
||||||
|
p.bug("Cannot infer primitive type of lex item '%s'.", lexItem)
|
||||||
|
panic("unreachable")
|
||||||
|
}
|
||||||
|
|
||||||
|
// typeOfArray returns a tomlType for an array given a list of types of its
|
||||||
|
// values.
|
||||||
|
//
|
||||||
|
// In the current spec, if an array is homogeneous, then its type is always
|
||||||
|
// "Array". If the array is not homogeneous, an error is generated.
|
||||||
|
func (p *parser) typeOfArray(types []tomlType) tomlType {
|
||||||
|
// Empty arrays are cool.
|
||||||
|
if len(types) == 0 {
|
||||||
|
return tomlArray
|
||||||
|
}
|
||||||
|
|
||||||
|
theType := types[0]
|
||||||
|
for _, t := range types[1:] {
|
||||||
|
if !typeEqual(theType, t) {
|
||||||
|
p.panicf("Array contains values of type '%s' and '%s', but "+
|
||||||
|
"arrays must be homogeneous.", theType, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tomlArray
|
||||||
|
}
|
@ -0,0 +1,242 @@
|
|||||||
|
package toml
|
||||||
|
|
||||||
|
// Struct field handling is adapted from code in encoding/json:
|
||||||
|
//
|
||||||
|
// Copyright 2010 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the Go distribution.
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"sort"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A field represents a single field found in a struct.
type field struct {
	name  string       // the name of the field (`toml` tag included)
	tag   bool         // whether field has a `toml` tag
	index []int        // represents the depth of an anonymous field
	typ   reflect.Type // the type of the field
}
|
||||||
|
|
||||||
|
// byName sorts field by name, breaking ties with depth,
|
||||||
|
// then breaking ties with "name came from toml tag", then
|
||||||
|
// breaking ties with index sequence.
|
||||||
|
type byName []field
|
||||||
|
|
||||||
|
func (x byName) Len() int { return len(x) }
|
||||||
|
|
||||||
|
func (x byName) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
|
||||||
|
|
||||||
|
func (x byName) Less(i, j int) bool {
|
||||||
|
if x[i].name != x[j].name {
|
||||||
|
return x[i].name < x[j].name
|
||||||
|
}
|
||||||
|
if len(x[i].index) != len(x[j].index) {
|
||||||
|
return len(x[i].index) < len(x[j].index)
|
||||||
|
}
|
||||||
|
if x[i].tag != x[j].tag {
|
||||||
|
return x[i].tag
|
||||||
|
}
|
||||||
|
return byIndex(x).Less(i, j)
|
||||||
|
}
|
||||||
|
|
||||||
|
// byIndex sorts field by index sequence.
|
||||||
|
type byIndex []field
|
||||||
|
|
||||||
|
func (x byIndex) Len() int { return len(x) }
|
||||||
|
|
||||||
|
func (x byIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
|
||||||
|
|
||||||
|
func (x byIndex) Less(i, j int) bool {
|
||||||
|
for k, xik := range x[i].index {
|
||||||
|
if k >= len(x[j].index) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if xik != x[j].index[k] {
|
||||||
|
return xik < x[j].index[k]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(x[i].index) < len(x[j].index)
|
||||||
|
}
|
||||||
|
|
||||||
|
// typeFields returns a list of fields that TOML should recognize for the given
// type. The algorithm is breadth-first search over the set of structs to
// include - the top struct and then any reachable anonymous structs.
//
// The result has Go's embedding-hiding rules already applied (fields with
// TOML tags are promoted over untagged ones) and is sorted by index
// sequence.
func typeFields(t reflect.Type) []field {
	// Anonymous fields to explore at the current level and the next.
	current := []field{}
	next := []field{{typ: t}}

	// Count of queued names for current level and the next. A count > 1
	// means the same embedded type occurs more than once at one depth, so
	// its fields conflict with each other (see the duplicate-append below).
	count := map[reflect.Type]int{}
	nextCount := map[reflect.Type]int{}

	// Types already visited at an earlier level.
	visited := map[reflect.Type]bool{}

	// Fields found.
	var fields []field

	for len(next) > 0 {
		current, next = next, current[:0]
		count, nextCount = nextCount, map[reflect.Type]int{}

		for _, f := range current {
			if visited[f.typ] {
				continue
			}
			visited[f.typ] = true

			// Scan f.typ for fields to include.
			for i := 0; i < f.typ.NumField(); i++ {
				sf := f.typ.Field(i)
				if sf.PkgPath != "" && !sf.Anonymous { // unexported
					continue
				}
				opts := getOptions(sf.Tag)
				if opts.skip {
					continue
				}
				// Extend the parent's index path with this field's position.
				index := make([]int, len(f.index)+1)
				copy(index, f.index)
				index[len(f.index)] = i

				ft := sf.Type
				if ft.Name() == "" && ft.Kind() == reflect.Ptr {
					// Follow pointer.
					ft = ft.Elem()
				}

				// Record found field and index sequence.
				if opts.name != "" || !sf.Anonymous || ft.Kind() != reflect.Struct {
					tagged := opts.name != ""
					name := opts.name
					if name == "" {
						name = sf.Name
					}
					fields = append(fields, field{name, tagged, index, ft})
					if count[f.typ] > 1 {
						// If there were multiple instances, add a second,
						// so that the annihilation code will see a duplicate.
						// It only cares about the distinction between 1 or 2,
						// so don't bother generating any more copies.
						fields = append(fields, fields[len(fields)-1])
					}
					continue
				}

				// Record new anonymous struct to explore in next round.
				nextCount[ft]++
				if nextCount[ft] == 1 {
					f := field{name: ft.Name(), index: index, typ: ft}
					next = append(next, f)
				}
			}
		}
	}

	sort.Sort(byName(fields))

	// Delete all fields that are hidden by the Go rules for embedded fields,
	// except that fields with TOML tags are promoted.

	// The fields are sorted in primary order of name, secondary order
	// of field index length. Loop over names; for each name, delete
	// hidden fields by choosing the one dominant field that survives.
	out := fields[:0]
	for advance, i := 0, 0; i < len(fields); i += advance {
		// One iteration per name.
		// Find the sequence of fields with the name of this first field.
		fi := fields[i]
		name := fi.name
		for advance = 1; i+advance < len(fields); advance++ {
			fj := fields[i+advance]
			if fj.name != name {
				break
			}
		}
		if advance == 1 { // Only one field with this name
			out = append(out, fi)
			continue
		}
		dominant, ok := dominantField(fields[i : i+advance])
		if ok {
			out = append(out, dominant)
		}
	}

	fields = out
	sort.Sort(byIndex(fields))

	return fields
}
|
||||||
|
|
||||||
|
// dominantField looks through the fields, all of which are known to
|
||||||
|
// have the same name, to find the single field that dominates the
|
||||||
|
// others using Go's embedding rules, modified by the presence of
|
||||||
|
// TOML tags. If there are multiple top-level fields, the boolean
|
||||||
|
// will be false: This condition is an error in Go and we skip all
|
||||||
|
// the fields.
|
||||||
|
func dominantField(fields []field) (field, bool) {
|
||||||
|
// The fields are sorted in increasing index-length order. The winner
|
||||||
|
// must therefore be one with the shortest index length. Drop all
|
||||||
|
// longer entries, which is easy: just truncate the slice.
|
||||||
|
length := len(fields[0].index)
|
||||||
|
tagged := -1 // Index of first tagged field.
|
||||||
|
for i, f := range fields {
|
||||||
|
if len(f.index) > length {
|
||||||
|
fields = fields[:i]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if f.tag {
|
||||||
|
if tagged >= 0 {
|
||||||
|
// Multiple tagged fields at the same level: conflict.
|
||||||
|
// Return no field.
|
||||||
|
return field{}, false
|
||||||
|
}
|
||||||
|
tagged = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if tagged >= 0 {
|
||||||
|
return fields[tagged], true
|
||||||
|
}
|
||||||
|
// All remaining fields have the same length. If there's more than one,
|
||||||
|
// we have a conflict (two fields named "X" at the same level) and we
|
||||||
|
// return no field.
|
||||||
|
if len(fields) > 1 {
|
||||||
|
return field{}, false
|
||||||
|
}
|
||||||
|
return fields[0], true
|
||||||
|
}
|
||||||
|
|
||||||
|
// fieldCache memoizes typeFields results per struct type, guarded by an
// RWMutex so concurrent decoders can share the computed field lists
// (see cachedTypeFields).
var fieldCache struct {
	sync.RWMutex
	m map[reflect.Type][]field
}
|
||||||
|
|
||||||
|
// cachedTypeFields is like typeFields but uses a cache to avoid repeated work.
|
||||||
|
func cachedTypeFields(t reflect.Type) []field {
|
||||||
|
fieldCache.RLock()
|
||||||
|
f := fieldCache.m[t]
|
||||||
|
fieldCache.RUnlock()
|
||||||
|
if f != nil {
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute fields without lock.
|
||||||
|
// Might duplicate effort but won't hold other computations back.
|
||||||
|
f = typeFields(t)
|
||||||
|
if f == nil {
|
||||||
|
f = []field{}
|
||||||
|
}
|
||||||
|
|
||||||
|
fieldCache.Lock()
|
||||||
|
if fieldCache.m == nil {
|
||||||
|
fieldCache.m = map[reflect.Type][]field{}
|
||||||
|
}
|
||||||
|
fieldCache.m[t] = f
|
||||||
|
fieldCache.Unlock()
|
||||||
|
return f
|
||||||
|
}
|
@ -0,0 +1,4 @@
|
|||||||
|
.DS_Store
|
||||||
|
bin
|
||||||
|
|
||||||
|
|
@ -0,0 +1,13 @@
|
|||||||
|
language: go
|
||||||
|
|
||||||
|
script:
|
||||||
|
- go vet ./...
|
||||||
|
- go test -v ./...
|
||||||
|
|
||||||
|
go:
|
||||||
|
- 1.3
|
||||||
|
- 1.4
|
||||||
|
- 1.5
|
||||||
|
- 1.6
|
||||||
|
- 1.7
|
||||||
|
- tip
|
@ -0,0 +1,8 @@
|
|||||||
|
Copyright (c) 2012 Dave Grijalva
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
|
|
@ -0,0 +1,97 @@
|
|||||||
|
## Migration Guide from v2 -> v3
|
||||||
|
|
||||||
|
Version 3 adds several new, frequently requested features. To do so, it introduces a few breaking changes. We've worked to keep these as minimal as possible. This guide explains the breaking changes and how you can quickly update your code.
|
||||||
|
|
||||||
|
### `Token.Claims` is now an interface type
|
||||||
|
|
||||||
|
The most requested feature from the 2.0 version of this library was the ability to provide a custom type to the JSON parser for claims. This was implemented by introducing a new interface, `Claims`, to replace `map[string]interface{}`. We also included two concrete implementations of `Claims`: `MapClaims` and `StandardClaims`.
|
||||||
|
|
||||||
|
`MapClaims` is an alias for `map[string]interface{}` with built in validation behavior. It is the default claims type when using `Parse`. The usage is unchanged except you must type cast the claims property.
|
||||||
|
|
||||||
|
The old example for parsing a token looked like this..
|
||||||
|
|
||||||
|
```go
|
||||||
|
if token, err := jwt.Parse(tokenString, keyLookupFunc); err == nil {
|
||||||
|
fmt.Printf("Token for user %v expires %v", token.Claims["user"], token.Claims["exp"])
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
is now directly mapped to...
|
||||||
|
|
||||||
|
```go
|
||||||
|
if token, err := jwt.Parse(tokenString, keyLookupFunc); err == nil {
|
||||||
|
claims := token.Claims.(jwt.MapClaims)
|
||||||
|
fmt.Printf("Token for user %v expires %v", claims["user"], claims["exp"])
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
`StandardClaims` is designed to be embedded in your custom type. You can supply a custom claims type with the new `ParseWithClaims` function. Here's an example of using a custom claims type.
|
||||||
|
|
||||||
|
```go
|
||||||
|
type MyCustomClaims struct {
|
||||||
|
User string
|
||||||
|
*StandardClaims
|
||||||
|
}
|
||||||
|
|
||||||
|
if token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, keyLookupFunc); err == nil {
|
||||||
|
claims := token.Claims.(*MyCustomClaims)
|
||||||
|
fmt.Printf("Token for user %v expires %v", claims.User, claims.StandardClaims.ExpiresAt)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### `ParseFromRequest` has been moved
|
||||||
|
|
||||||
|
To keep this library focused on the tokens without becoming overburdened with complex request processing logic, `ParseFromRequest` and its new companion `ParseFromRequestWithClaims` have been moved to a subpackage, `request`. The method signatures have also been augmented to receive a new argument: `Extractor`.
|
||||||
|
|
||||||
|
`Extractors` do the work of picking the token string out of a request. The interface is simple and composable.
|
||||||
|
|
||||||
|
This simple parsing example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
if token, err := jwt.ParseFromRequest(tokenString, req, keyLookupFunc); err == nil {
|
||||||
|
fmt.Printf("Token for user %v expires %v", token.Claims["user"], token.Claims["exp"])
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
is directly mapped to:
|
||||||
|
|
||||||
|
```go
|
||||||
|
if token, err := request.ParseFromRequest(req, request.OAuth2Extractor, keyLookupFunc); err == nil {
|
||||||
|
claims := token.Claims.(jwt.MapClaims)
|
||||||
|
fmt.Printf("Token for user %v expires %v", claims["user"], claims["exp"])
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
There are several concrete `Extractor` types provided for your convenience:
|
||||||
|
|
||||||
|
* `HeaderExtractor` will search a list of headers until one contains content.
|
||||||
|
* `ArgumentExtractor` will search a list of keys in request query and form arguments until one contains content.
|
||||||
|
* `MultiExtractor` will try a list of `Extractors` in order until one returns content.
|
||||||
|
* `AuthorizationHeaderExtractor` will look in the `Authorization` header for a `Bearer` token.
|
||||||
|
* `OAuth2Extractor` searches the places an OAuth2 token would be specified (per the spec): `Authorization` header and `access_token` argument
|
||||||
|
* `PostExtractionFilter` wraps an `Extractor`, allowing you to process the content before it's parsed. A simple example is stripping the `Bearer ` text from a header
|
||||||
|
|
||||||
|
|
||||||
|
### RSA signing methods no longer accept `[]byte` keys
|
||||||
|
|
||||||
|
Due to a [critical vulnerability](https://auth0.com/blog/2015/03/31/critical-vulnerabilities-in-json-web-token-libraries/), we've decided the convenience of accepting `[]byte` instead of `rsa.PublicKey` or `rsa.PrivateKey` isn't worth the risk of misuse.
|
||||||
|
|
||||||
|
To replace this behavior, we've added two helper methods: `ParseRSAPrivateKeyFromPEM(key []byte) (*rsa.PrivateKey, error)` and `ParseRSAPublicKeyFromPEM(key []byte) (*rsa.PublicKey, error)`. These are just simple helpers for unpacking PEM encoded PKCS1 and PKCS8 keys. If your keys are encoded any other way, all you need to do is convert them to the `crypto/rsa` package's types.
|
||||||
|
|
||||||
|
```go
|
||||||
|
func keyLookupFunc(token *Token) (interface{}, error) {
|
||||||
|
// Don't forget to validate the alg is what you expect:
|
||||||
|
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||||
|
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look up key
|
||||||
|
key, err := lookupPublicKey(token.Header["kid"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unpack key from PEM encoded PKCS8
|
||||||
|
return jwt.ParseRSAPublicKeyFromPEM(key)
|
||||||
|
}
|
||||||
|
```
|
@ -0,0 +1,100 @@
|
|||||||
|
# jwt-go
|
||||||
|
|
||||||
|
[](https://travis-ci.org/dgrijalva/jwt-go)
|
||||||
|
[](https://godoc.org/github.com/dgrijalva/jwt-go)
|
||||||
|
|
||||||
|
A [go](http://www.golang.org) (or 'golang' for search engine friendliness) implementation of [JSON Web Tokens](http://self-issued.info/docs/draft-ietf-oauth-json-web-token.html)
|
||||||
|
|
||||||
|
**NEW VERSION COMING:** There have been a lot of improvements suggested since the version 3.0.0 released in 2016. I'm working now on cutting two different releases: 3.2.0 will contain any non-breaking changes or enhancements. 4.0.0 will follow shortly which will include breaking changes. See the 4.0.0 milestone to get an idea of what's coming. If you have other ideas, or would like to participate in 4.0.0, now's the time. If you depend on this library and don't want to be interrupted, I recommend you use your dependency management tool to pin to version 3.
|
||||||
|
|
||||||
|
**SECURITY NOTICE:** Some older versions of Go have a security issue in the crypto/elliptic package. Recommendation is to upgrade to at least 1.8.3. See issue #216 for more detail.
|
||||||
|
|
||||||
|
**SECURITY NOTICE:** It's important that you [validate the `alg` presented is what you expect](https://auth0.com/blog/2015/03/31/critical-vulnerabilities-in-json-web-token-libraries/). This library attempts to make it easy to do the right thing by requiring key types match the expected alg, but you should take the extra step to verify it in your usage. See the examples provided.
|
||||||
|
|
||||||
|
## What the heck is a JWT?
|
||||||
|
|
||||||
|
JWT.io has [a great introduction](https://jwt.io/introduction) to JSON Web Tokens.
|
||||||
|
|
||||||
|
In short, it's a signed JSON object that does something useful (for example, authentication). It's commonly used for `Bearer` tokens in Oauth 2. A token is made of three parts, separated by `.`'s. The first two parts are JSON objects, that have been [base64url](http://tools.ietf.org/html/rfc4648) encoded. The last part is the signature, encoded the same way.
|
||||||
|
|
||||||
|
The first part is called the header. It contains the necessary information for verifying the last part, the signature. For example, which encryption method was used for signing and what key was used.
|
||||||
|
|
||||||
|
The part in the middle is the interesting bit. It's called the Claims and contains the actual stuff you care about. Refer to [the RFC](http://self-issued.info/docs/draft-jones-json-web-token.html) for information about reserved keys and the proper way to add your own.
|
||||||
|
|
||||||
|
## What's in the box?
|
||||||
|
|
||||||
|
This library supports the parsing and verification as well as the generation and signing of JWTs. Current supported signing algorithms are HMAC SHA, RSA, RSA-PSS, and ECDSA, though hooks are present for adding your own.
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
See [the project documentation](https://godoc.org/github.com/dgrijalva/jwt-go) for examples of usage:
|
||||||
|
|
||||||
|
* [Simple example of parsing and validating a token](https://godoc.org/github.com/dgrijalva/jwt-go#example-Parse--Hmac)
|
||||||
|
* [Simple example of building and signing a token](https://godoc.org/github.com/dgrijalva/jwt-go#example-New--Hmac)
|
||||||
|
* [Directory of Examples](https://godoc.org/github.com/dgrijalva/jwt-go#pkg-examples)
|
||||||
|
|
||||||
|
## Extensions
|
||||||
|
|
||||||
|
This library publishes all the necessary components for adding your own signing methods. Simply implement the `SigningMethod` interface and register a factory method using `RegisterSigningMethod`.
|
||||||
|
|
||||||
|
Here's an example of an extension that integrates with the Google App Engine signing tools: https://github.com/someone1/gcp-jwt-go
|
||||||
|
|
||||||
|
## Compliance
|
||||||
|
|
||||||
|
This library was last reviewed to comply with [RFC 7519](http://www.rfc-editor.org/info/rfc7519) dated May 2015 with a few notable differences:
|
||||||
|
|
||||||
|
* In order to protect against accidental use of [Unsecured JWTs](http://self-issued.info/docs/draft-ietf-oauth-json-web-token.html#UnsecuredJWT), tokens using `alg=none` will only be accepted if the constant `jwt.UnsafeAllowNoneSignatureType` is provided as the key.
|
||||||
|
|
||||||
|
## Project Status & Versioning
|
||||||
|
|
||||||
|
This library is considered production ready. Feedback and feature requests are appreciated. The API should be considered stable. There should be very few backwards-incompatible changes outside of major version updates (and only with good reason).
|
||||||
|
|
||||||
|
This project uses [Semantic Versioning 2.0.0](http://semver.org). Accepted pull requests will land on `master`. Periodically, versions will be tagged from `master`. You can find all the releases on [the project releases page](https://github.com/dgrijalva/jwt-go/releases).
|
||||||
|
|
||||||
|
While we try to make it obvious when we make breaking changes, there isn't a great mechanism for pushing announcements out to users. You may want to use this alternative package include: `gopkg.in/dgrijalva/jwt-go.v3`. It will do the right thing WRT semantic versioning.
|
||||||
|
|
||||||
|
**BREAKING CHANGES:**
|
||||||
|
* Version 3.0.0 includes _a lot_ of changes from the 2.x line, including a few that break the API. We've tried to break as few things as possible, so there should just be a few type signature changes. A full list of breaking changes is available in `VERSION_HISTORY.md`. See `MIGRATION_GUIDE.md` for more information on updating your code.
|
||||||
|
|
||||||
|
## Usage Tips
|
||||||
|
|
||||||
|
### Signing vs Encryption
|
||||||
|
|
||||||
|
A token is simply a JSON object that is signed by its author. this tells you exactly two things about the data:
|
||||||
|
|
||||||
|
* The author of the token was in the possession of the signing secret
|
||||||
|
* The data has not been modified since it was signed
|
||||||
|
|
||||||
|
It's important to know that JWT does not provide encryption, which means anyone who has access to the token can read its contents. If you need to protect (encrypt) the data, there is a companion spec, `JWE`, that provides this functionality. JWE is currently outside the scope of this library.
|
||||||
|
|
||||||
|
### Choosing a Signing Method
|
||||||
|
|
||||||
|
There are several signing methods available, and you should probably take the time to learn about the various options before choosing one. The principal design decision is most likely going to be symmetric vs asymmetric.
|
||||||
|
|
||||||
|
Symmetric signing methods, such as HMAC, use only a single secret. This is probably the simplest signing method to use since any `[]byte` can be used as a valid secret. They are also slightly computationally faster to use, though this rarely is enough to matter. Symmetric signing methods work the best when both producers and consumers of tokens are trusted, or even the same system. Since the same secret is used to both sign and validate tokens, you can't easily distribute the key for validation.
|
||||||
|
|
||||||
|
Asymmetric signing methods, such as RSA, use different keys for signing and verifying tokens. This makes it possible to produce tokens with a private key, and allow any consumer to access the public key for verification.
|
||||||
|
|
||||||
|
### Signing Methods and Key Types
|
||||||
|
|
||||||
|
Each signing method expects a different object type for its signing keys. See the package documentation for details. Here are the most common ones:
|
||||||
|
|
||||||
|
* The [HMAC signing method](https://godoc.org/github.com/dgrijalva/jwt-go#SigningMethodHMAC) (`HS256`,`HS384`,`HS512`) expect `[]byte` values for signing and validation
|
||||||
|
* The [RSA signing method](https://godoc.org/github.com/dgrijalva/jwt-go#SigningMethodRSA) (`RS256`,`RS384`,`RS512`) expect `*rsa.PrivateKey` for signing and `*rsa.PublicKey` for validation
|
||||||
|
* The [ECDSA signing method](https://godoc.org/github.com/dgrijalva/jwt-go#SigningMethodECDSA) (`ES256`,`ES384`,`ES512`) expect `*ecdsa.PrivateKey` for signing and `*ecdsa.PublicKey` for validation
|
||||||
|
|
||||||
|
### JWT and OAuth
|
||||||
|
|
||||||
|
It's worth mentioning that OAuth and JWT are not the same thing. A JWT token is simply a signed JSON object. It can be used anywhere such a thing is useful. There is some confusion, though, as JWT is the most common type of bearer token used in OAuth2 authentication.
|
||||||
|
|
||||||
|
Without going too far down the rabbit hole, here's a description of the interaction of these technologies:
|
||||||
|
|
||||||
|
* OAuth is a protocol for allowing an identity provider to be separate from the service a user is logging in to. For example, whenever you use Facebook to log into a different service (Yelp, Spotify, etc), you are using OAuth.
|
||||||
|
* OAuth defines several options for passing around authentication data. One popular method is called a "bearer token". A bearer token is simply a string that _should_ only be held by an authenticated user. Thus, simply presenting this token proves your identity. You can probably derive from here why a JWT might make a good bearer token.
|
||||||
|
* Because bearer tokens are used for authentication, it's important they're kept secret. This is why transactions that use bearer tokens typically happen over SSL.
|
||||||
|
|
||||||
|
## More
|
||||||
|
|
||||||
|
Documentation can be found [on godoc.org](http://godoc.org/github.com/dgrijalva/jwt-go).
|
||||||
|
|
||||||
|
The command line utility included in this project (cmd/jwt) provides a straightforward example of token creation and parsing as well as a useful tool for debugging your own integration. You'll also find several implementation examples in the documentation.
|
@ -0,0 +1,118 @@
|
|||||||
|
## `jwt-go` Version History
|
||||||
|
|
||||||
|
#### 3.2.0
|
||||||
|
|
||||||
|
* Added method `ParseUnverified` to allow users to split up the tasks of parsing and validation
|
||||||
|
* HMAC signing method returns `ErrInvalidKeyType` instead of `ErrInvalidKey` where appropriate
|
||||||
|
* Added options to `request.ParseFromRequest`, which allows for an arbitrary list of modifiers to parsing behavior. Initial set include `WithClaims` and `WithParser`. Existing usage of this function will continue to work as before.
|
||||||
|
* Deprecated `ParseFromRequestWithClaims` to simplify API in the future.
|
||||||
|
|
||||||
|
#### 3.1.0
|
||||||
|
|
||||||
|
* Improvements to `jwt` command line tool
|
||||||
|
* Added `SkipClaimsValidation` option to `Parser`
|
||||||
|
* Documentation updates
|
||||||
|
|
||||||
|
#### 3.0.0
|
||||||
|
|
||||||
|
* **Compatibility Breaking Changes**: See MIGRATION_GUIDE.md for tips on updating your code
|
||||||
|
* Dropped support for `[]byte` keys when using RSA signing methods. This convenience feature could contribute to security vulnerabilities involving mismatched key types with signing methods.
|
||||||
|
* `ParseFromRequest` has been moved to `request` subpackage and usage has changed
|
||||||
|
* The `Claims` property on `Token` is now type `Claims` instead of `map[string]interface{}`. The default value is type `MapClaims`, which is an alias to `map[string]interface{}`. This makes it possible to use a custom type when decoding claims.
|
||||||
|
* Other Additions and Changes
|
||||||
|
* Added `Claims` interface type to allow users to decode the claims into a custom type
|
||||||
|
* Added `ParseWithClaims`, which takes a third argument of type `Claims`. Use this function instead of `Parse` if you have a custom type you'd like to decode into.
|
||||||
|
* Dramatically improved the functionality and flexibility of `ParseFromRequest`, which is now in the `request` subpackage
|
||||||
|
* Added `ParseFromRequestWithClaims` which is the `FromRequest` equivalent of `ParseWithClaims`
|
||||||
|
* Added new interface type `Extractor`, which is used for extracting JWT strings from http requests. Used with `ParseFromRequest` and `ParseFromRequestWithClaims`.
|
||||||
|
* Added several new, more specific, validation errors to error type bitmask
|
||||||
|
* Moved examples from README to executable example files
|
||||||
|
* Signing method registry is now thread safe
|
||||||
|
* Added new property to `ValidationError`, which contains the raw error returned by calls made by parse/verify (such as those returned by keyfunc or json parser)
|
||||||
|
|
||||||
|
#### 2.7.0
|
||||||
|
|
||||||
|
This will likely be the last backwards compatible release before 3.0.0, excluding essential bug fixes.
|
||||||
|
|
||||||
|
* Added new option `-show` to the `jwt` command that will just output the decoded token without verifying
|
||||||
|
* Error text for expired tokens includes how long it's been expired
|
||||||
|
* Fixed incorrect error returned from `ParseRSAPublicKeyFromPEM`
|
||||||
|
* Documentation updates
|
||||||
|
|
||||||
|
#### 2.6.0
|
||||||
|
|
||||||
|
* Exposed inner error within ValidationError
|
||||||
|
* Fixed validation errors when using UseJSONNumber flag
|
||||||
|
* Added several unit tests
|
||||||
|
|
||||||
|
#### 2.5.0
|
||||||
|
|
||||||
|
* Added support for signing method none. You shouldn't use this. The API tries to make this clear.
|
||||||
|
* Updated/fixed some documentation
|
||||||
|
* Added more helpful error message when trying to parse tokens that begin with `BEARER `
|
||||||
|
|
||||||
|
#### 2.4.0
|
||||||
|
|
||||||
|
* Added new type, Parser, to allow for configuration of various parsing parameters
|
||||||
|
* You can now specify a list of valid signing methods. Anything outside this set will be rejected.
|
||||||
|
* You can now opt to use the `json.Number` type instead of `float64` when parsing token JSON
|
||||||
|
* Added support for [Travis CI](https://travis-ci.org/dgrijalva/jwt-go)
|
||||||
|
* Fixed some bugs with ECDSA parsing
|
||||||
|
|
||||||
|
#### 2.3.0
|
||||||
|
|
||||||
|
* Added support for ECDSA signing methods
|
||||||
|
* Added support for RSA PSS signing methods (requires go v1.4)
|
||||||
|
|
||||||
|
#### 2.2.0
|
||||||
|
|
||||||
|
* Gracefully handle a `nil` `Keyfunc` being passed to `Parse`. Result will now be the parsed token and an error, instead of a panic.
|
||||||
|
|
||||||
|
#### 2.1.0
|
||||||
|
|
||||||
|
Backwards compatible API change that was missed in 2.0.0.
|
||||||
|
|
||||||
|
* The `SignedString` method on `Token` now takes `interface{}` instead of `[]byte`
|
||||||
|
|
||||||
|
#### 2.0.0
|
||||||
|
|
||||||
|
There were two major reasons for breaking backwards compatibility with this update. The first was a refactor required to expand the width of the RSA and HMAC-SHA signing implementations. There will likely be no required code changes to support this change.
|
||||||
|
|
||||||
|
The second update, while unfortunately requiring a small change in integration, is required to open up this library to other signing methods. Not all keys used for all signing methods have a single standard on-disk representation. Requiring `[]byte` as the type for all keys proved too limiting. Additionally, this implementation allows for pre-parsed tokens to be reused, which might matter in an application that parses a high volume of tokens with a small set of keys. Backwards compatibility has been maintained for passing `[]byte` to the RSA signing methods, but they will also accept `*rsa.PublicKey` and `*rsa.PrivateKey`.
|
||||||
|
|
||||||
|
It is likely the only integration change required here will be to change `func(t *jwt.Token) ([]byte, error)` to `func(t *jwt.Token) (interface{}, error)` when calling `Parse`.
|
||||||
|
|
||||||
|
* **Compatibility Breaking Changes**
|
||||||
|
* `SigningMethodHS256` is now `*SigningMethodHMAC` instead of `type struct`
|
||||||
|
* `SigningMethodRS256` is now `*SigningMethodRSA` instead of `type struct`
|
||||||
|
* `KeyFunc` now returns `interface{}` instead of `[]byte`
|
||||||
|
* `SigningMethod.Sign` now takes `interface{}` instead of `[]byte` for the key
|
||||||
|
* `SigningMethod.Verify` now takes `interface{}` instead of `[]byte` for the key
|
||||||
|
* Renamed type `SigningMethodHS256` to `SigningMethodHMAC`. Specific sizes are now just instances of this type.
|
||||||
|
* Added public package global `SigningMethodHS256`
|
||||||
|
* Added public package global `SigningMethodHS384`
|
||||||
|
* Added public package global `SigningMethodHS512`
|
||||||
|
* Renamed type `SigningMethodRS256` to `SigningMethodRSA`. Specific sizes are now just instances of this type.
|
||||||
|
* Added public package global `SigningMethodRS256`
|
||||||
|
* Added public package global `SigningMethodRS384`
|
||||||
|
* Added public package global `SigningMethodRS512`
|
||||||
|
* Moved sample private key for HMAC tests from an inline value to a file on disk. Value is unchanged.
|
||||||
|
* Refactored the RSA implementation to be easier to read
|
||||||
|
* Exposed helper methods `ParseRSAPrivateKeyFromPEM` and `ParseRSAPublicKeyFromPEM`
|
||||||
|
|
||||||
|
#### 1.0.2
|
||||||
|
|
||||||
|
* Fixed bug in parsing public keys from certificates
|
||||||
|
* Added more tests around the parsing of keys for RS256
|
||||||
|
* Code refactoring in RS256 implementation. No functional changes
|
||||||
|
|
||||||
|
#### 1.0.1
|
||||||
|
|
||||||
|
* Fixed panic if RS256 signing method was passed an invalid key
|
||||||
|
|
||||||
|
#### 1.0.0
|
||||||
|
|
||||||
|
* First versioned release
|
||||||
|
* API stabilized
|
||||||
|
* Supports creating, signing, parsing, and validating JWT tokens
|
||||||
|
* Supports RS256 and HS256 signing methods
|
@ -0,0 +1,134 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/subtle"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Claims is implemented by any type that can report, via Valid,
// whether the token it represents is invalid for any supported reason.
type Claims interface {
	Valid() error
}

// StandardClaims is a structured version of the JWT Claims Set, as
// referenced at https://tools.ietf.org/html/rfc7519#section-4.1.
// See the examples for how to use this with your own claim types.
type StandardClaims struct {
	Audience  string `json:"aud,omitempty"`
	ExpiresAt int64  `json:"exp,omitempty"`
	Id        string `json:"jti,omitempty"`
	IssuedAt  int64  `json:"iat,omitempty"`
	Issuer    string `json:"iss,omitempty"`
	NotBefore int64  `json:"nbf,omitempty"`
	Subject   string `json:"sub,omitempty"`
}
|
||||||
|
|
||||||
|
// Validates time based claims "exp, iat, nbf".
|
||||||
|
// There is no accounting for clock skew.
|
||||||
|
// As well, if any of the above claims are not in the token, it will still
|
||||||
|
// be considered a valid claim.
|
||||||
|
func (c StandardClaims) Valid() error {
|
||||||
|
vErr := new(ValidationError)
|
||||||
|
now := TimeFunc().Unix()
|
||||||
|
|
||||||
|
// The claims below are optional, by default, so if they are set to the
|
||||||
|
// default value in Go, let's not fail the verification for them.
|
||||||
|
if c.VerifyExpiresAt(now, false) == false {
|
||||||
|
delta := time.Unix(now, 0).Sub(time.Unix(c.ExpiresAt, 0))
|
||||||
|
vErr.Inner = fmt.Errorf("token is expired by %v", delta)
|
||||||
|
vErr.Errors |= ValidationErrorExpired
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.VerifyIssuedAt(now, false) == false {
|
||||||
|
vErr.Inner = fmt.Errorf("Token used before issued")
|
||||||
|
vErr.Errors |= ValidationErrorIssuedAt
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.VerifyNotBefore(now, false) == false {
|
||||||
|
vErr.Inner = fmt.Errorf("token is not valid yet")
|
||||||
|
vErr.Errors |= ValidationErrorNotValidYet
|
||||||
|
}
|
||||||
|
|
||||||
|
if vErr.valid() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return vErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compares the aud claim against cmp.
|
||||||
|
// If required is false, this method will return true if the value matches or is unset
|
||||||
|
func (c *StandardClaims) VerifyAudience(cmp string, req bool) bool {
|
||||||
|
return verifyAud(c.Audience, cmp, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compares the exp claim against cmp.
|
||||||
|
// If required is false, this method will return true if the value matches or is unset
|
||||||
|
func (c *StandardClaims) VerifyExpiresAt(cmp int64, req bool) bool {
|
||||||
|
return verifyExp(c.ExpiresAt, cmp, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compares the iat claim against cmp.
|
||||||
|
// If required is false, this method will return true if the value matches or is unset
|
||||||
|
func (c *StandardClaims) VerifyIssuedAt(cmp int64, req bool) bool {
|
||||||
|
return verifyIat(c.IssuedAt, cmp, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compares the iss claim against cmp.
|
||||||
|
// If required is false, this method will return true if the value matches or is unset
|
||||||
|
func (c *StandardClaims) VerifyIssuer(cmp string, req bool) bool {
|
||||||
|
return verifyIss(c.Issuer, cmp, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compares the nbf claim against cmp.
|
||||||
|
// If required is false, this method will return true if the value matches or is unset
|
||||||
|
func (c *StandardClaims) VerifyNotBefore(cmp int64, req bool) bool {
|
||||||
|
return verifyNbf(c.NotBefore, cmp, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- helpers

// verifyAud reports whether the aud claim matches cmp using a
// constant-time comparison. An unset claim passes unless required.
func verifyAud(aud string, cmp string, required bool) bool {
	if aud == "" {
		return !required
	}
	return subtle.ConstantTimeCompare([]byte(aud), []byte(cmp)) != 0
}

// verifyExp reports whether the token has not yet expired at time now.
// An unset (zero) exp passes unless required.
func verifyExp(exp int64, now int64, required bool) bool {
	if exp == 0 {
		return !required
	}
	return now <= exp
}

// verifyIat reports whether the token was issued at or before time now.
// An unset (zero) iat passes unless required.
func verifyIat(iat int64, now int64, required bool) bool {
	if iat == 0 {
		return !required
	}
	return now >= iat
}

// verifyIss reports whether the iss claim matches cmp using a
// constant-time comparison. An unset claim passes unless required.
func verifyIss(iss string, cmp string, required bool) bool {
	if iss == "" {
		return !required
	}
	return subtle.ConstantTimeCompare([]byte(iss), []byte(cmp)) != 0
}

// verifyNbf reports whether the token is already valid at time now.
// An unset (zero) nbf passes unless required.
func verifyNbf(nbf int64, now int64, required bool) bool {
	if nbf == 0 {
		return !required
	}
	return now >= nbf
}
|
@ -0,0 +1,4 @@
|
|||||||
|
// Package jwt is a Go implementation of JSON Web Tokens: http://self-issued.info/docs/draft-jones-json-web-token.html
|
||||||
|
//
|
||||||
|
// See README.md for more info.
|
||||||
|
package jwt
|
@ -0,0 +1,148 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/rand"
|
||||||
|
"errors"
|
||||||
|
"math/big"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
	// ErrECDSAVerification is the sentinel verification failure;
	// sadly this is missing from crypto/ecdsa compared to crypto/rsa.
	ErrECDSAVerification = errors.New("crypto/ecdsa: verification error")
)

// SigningMethodECDSA implements the ECDSA family of signing methods.
// It expects *ecdsa.PrivateKey for signing and *ecdsa.PublicKey for
// verification.
type SigningMethodECDSA struct {
	Name      string
	Hash      crypto.Hash
	KeySize   int
	CurveBits int
}

// Specific instances for ES256 and company.
var (
	SigningMethodES256 *SigningMethodECDSA
	SigningMethodES384 *SigningMethodECDSA
	SigningMethodES512 *SigningMethodECDSA
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// ES256
|
||||||
|
SigningMethodES256 = &SigningMethodECDSA{"ES256", crypto.SHA256, 32, 256}
|
||||||
|
RegisterSigningMethod(SigningMethodES256.Alg(), func() SigningMethod {
|
||||||
|
return SigningMethodES256
|
||||||
|
})
|
||||||
|
|
||||||
|
// ES384
|
||||||
|
SigningMethodES384 = &SigningMethodECDSA{"ES384", crypto.SHA384, 48, 384}
|
||||||
|
RegisterSigningMethod(SigningMethodES384.Alg(), func() SigningMethod {
|
||||||
|
return SigningMethodES384
|
||||||
|
})
|
||||||
|
|
||||||
|
// ES512
|
||||||
|
SigningMethodES512 = &SigningMethodECDSA{"ES512", crypto.SHA512, 66, 521}
|
||||||
|
RegisterSigningMethod(SigningMethodES512.Alg(), func() SigningMethod {
|
||||||
|
return SigningMethodES512
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *SigningMethodECDSA) Alg() string {
|
||||||
|
return m.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implements the Verify method from SigningMethod
|
||||||
|
// For this verify method, key must be an ecdsa.PublicKey struct
|
||||||
|
func (m *SigningMethodECDSA) Verify(signingString, signature string, key interface{}) error {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Decode the signature
|
||||||
|
var sig []byte
|
||||||
|
if sig, err = DecodeSegment(signature); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the key
|
||||||
|
var ecdsaKey *ecdsa.PublicKey
|
||||||
|
switch k := key.(type) {
|
||||||
|
case *ecdsa.PublicKey:
|
||||||
|
ecdsaKey = k
|
||||||
|
default:
|
||||||
|
return ErrInvalidKeyType
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(sig) != 2*m.KeySize {
|
||||||
|
return ErrECDSAVerification
|
||||||
|
}
|
||||||
|
|
||||||
|
r := big.NewInt(0).SetBytes(sig[:m.KeySize])
|
||||||
|
s := big.NewInt(0).SetBytes(sig[m.KeySize:])
|
||||||
|
|
||||||
|
// Create hasher
|
||||||
|
if !m.Hash.Available() {
|
||||||
|
return ErrHashUnavailable
|
||||||
|
}
|
||||||
|
hasher := m.Hash.New()
|
||||||
|
hasher.Write([]byte(signingString))
|
||||||
|
|
||||||
|
// Verify the signature
|
||||||
|
if verifystatus := ecdsa.Verify(ecdsaKey, hasher.Sum(nil), r, s); verifystatus == true {
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
return ErrECDSAVerification
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implements the Sign method from SigningMethod
|
||||||
|
// For this signing method, key must be an ecdsa.PrivateKey struct
|
||||||
|
func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) (string, error) {
|
||||||
|
// Get the key
|
||||||
|
var ecdsaKey *ecdsa.PrivateKey
|
||||||
|
switch k := key.(type) {
|
||||||
|
case *ecdsa.PrivateKey:
|
||||||
|
ecdsaKey = k
|
||||||
|
default:
|
||||||
|
return "", ErrInvalidKeyType
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the hasher
|
||||||
|
if !m.Hash.Available() {
|
||||||
|
return "", ErrHashUnavailable
|
||||||
|
}
|
||||||
|
|
||||||
|
hasher := m.Hash.New()
|
||||||
|
hasher.Write([]byte(signingString))
|
||||||
|
|
||||||
|
// Sign the string and return r, s
|
||||||
|
if r, s, err := ecdsa.Sign(rand.Reader, ecdsaKey, hasher.Sum(nil)); err == nil {
|
||||||
|
curveBits := ecdsaKey.Curve.Params().BitSize
|
||||||
|
|
||||||
|
if m.CurveBits != curveBits {
|
||||||
|
return "", ErrInvalidKey
|
||||||
|
}
|
||||||
|
|
||||||
|
keyBytes := curveBits / 8
|
||||||
|
if curveBits%8 > 0 {
|
||||||
|
keyBytes += 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// We serialize the outpus (r and s) into big-endian byte arrays and pad
|
||||||
|
// them with zeros on the left to make sure the sizes work out. Both arrays
|
||||||
|
// must be keyBytes long, and the output must be 2*keyBytes long.
|
||||||
|
rBytes := r.Bytes()
|
||||||
|
rBytesPadded := make([]byte, keyBytes)
|
||||||
|
copy(rBytesPadded[keyBytes-len(rBytes):], rBytes)
|
||||||
|
|
||||||
|
sBytes := s.Bytes()
|
||||||
|
sBytesPadded := make([]byte, keyBytes)
|
||||||
|
copy(sBytesPadded[keyBytes-len(sBytes):], sBytes)
|
||||||
|
|
||||||
|
out := append(rBytesPadded, sBytesPadded...)
|
||||||
|
|
||||||
|
return EncodeSegment(out), nil
|
||||||
|
} else {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,67 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Errors returned when a PEM block does not contain a usable ECDSA key.
var (
	ErrNotECPublicKey  = errors.New("Key is not a valid ECDSA public key")
	ErrNotECPrivateKey = errors.New("Key is not a valid ECDSA private key")
)
|
||||||
|
|
||||||
|
// Parse PEM encoded Elliptic Curve Private Key Structure
|
||||||
|
func ParseECPrivateKeyFromPEM(key []byte) (*ecdsa.PrivateKey, error) {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Parse PEM block
|
||||||
|
var block *pem.Block
|
||||||
|
if block, _ = pem.Decode(key); block == nil {
|
||||||
|
return nil, ErrKeyMustBePEMEncoded
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the key
|
||||||
|
var parsedKey interface{}
|
||||||
|
if parsedKey, err = x509.ParseECPrivateKey(block.Bytes); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var pkey *ecdsa.PrivateKey
|
||||||
|
var ok bool
|
||||||
|
if pkey, ok = parsedKey.(*ecdsa.PrivateKey); !ok {
|
||||||
|
return nil, ErrNotECPrivateKey
|
||||||
|
}
|
||||||
|
|
||||||
|
return pkey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse PEM encoded PKCS1 or PKCS8 public key
|
||||||
|
func ParseECPublicKeyFromPEM(key []byte) (*ecdsa.PublicKey, error) {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Parse PEM block
|
||||||
|
var block *pem.Block
|
||||||
|
if block, _ = pem.Decode(key); block == nil {
|
||||||
|
return nil, ErrKeyMustBePEMEncoded
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the key
|
||||||
|
var parsedKey interface{}
|
||||||
|
if parsedKey, err = x509.ParsePKIXPublicKey(block.Bytes); err != nil {
|
||||||
|
if cert, err := x509.ParseCertificate(block.Bytes); err == nil {
|
||||||
|
parsedKey = cert.PublicKey
|
||||||
|
} else {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var pkey *ecdsa.PublicKey
|
||||||
|
var ok bool
|
||||||
|
if pkey, ok = parsedKey.(*ecdsa.PublicKey); !ok {
|
||||||
|
return nil, ErrNotECPublicKey
|
||||||
|
}
|
||||||
|
|
||||||
|
return pkey, nil
|
||||||
|
}
|
@ -0,0 +1,59 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Error constants returned when a key cannot be used.
var (
	ErrInvalidKey      = errors.New("key is invalid")
	ErrInvalidKeyType  = errors.New("key is of invalid type")
	ErrHashUnavailable = errors.New("the requested hash function is unavailable")
)

// Bit flags describing the errors that might occur when parsing and
// validating a token. Multiple flags may be OR-ed into one bitfield.
const (
	ValidationErrorMalformed        uint32 = 1 << iota // Token is malformed
	ValidationErrorUnverifiable                        // Token could not be verified because of signing problems
	ValidationErrorSignatureInvalid                    // Signature validation failed

	// Standard Claim validation errors
	ValidationErrorAudience      // AUD validation failed
	ValidationErrorExpired       // EXP validation failed
	ValidationErrorIssuedAt      // IAT validation failed
	ValidationErrorIssuer        // ISS validation failed
	ValidationErrorNotValidYet   // NBF validation failed
	ValidationErrorId            // JTI validation failed
	ValidationErrorClaimsInvalid // Generic claims validation error
)

// NewValidationError is a helper for constructing a ValidationError
// with a string error message.
func NewValidationError(errorText string, errorFlags uint32) *ValidationError {
	return &ValidationError{
		text:   errorText,
		Errors: errorFlags,
	}
}

// ValidationError is the error returned from Parse if a token is not valid.
type ValidationError struct {
	Inner  error  // stores the error returned by external dependencies, i.e.: KeyFunc
	Errors uint32 // bitfield. see ValidationError... constants
	text   string // errors that do not have a valid error just have text
}

// Error implements the error interface. Preference order: the wrapped
// inner error, then the plain text, then a generic message.
func (e ValidationError) Error() string {
	switch {
	case e.Inner != nil:
		return e.Inner.Error()
	case e.text != "":
		return e.text
	default:
		return "token is invalid"
	}
}

// valid reports whether no error flags are set.
func (e *ValidationError) valid() bool {
	return e.Errors == 0
}
|
@ -0,0 +1,95 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto"
|
||||||
|
"crypto/hmac"
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SigningMethodHMAC implements the HMAC-SHA family of signing methods.
// It expects a key of type []byte for both signing and validation.
type SigningMethodHMAC struct {
	Name string
	Hash crypto.Hash
}

// Specific instances for HS256 and company.
var (
	SigningMethodHS256  *SigningMethodHMAC
	SigningMethodHS384  *SigningMethodHMAC
	SigningMethodHS512  *SigningMethodHMAC
	ErrSignatureInvalid = errors.New("signature is invalid")
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// HS256
|
||||||
|
SigningMethodHS256 = &SigningMethodHMAC{"HS256", crypto.SHA256}
|
||||||
|
RegisterSigningMethod(SigningMethodHS256.Alg(), func() SigningMethod {
|
||||||
|
return SigningMethodHS256
|
||||||
|
})
|
||||||
|
|
||||||
|
// HS384
|
||||||
|
SigningMethodHS384 = &SigningMethodHMAC{"HS384", crypto.SHA384}
|
||||||
|
RegisterSigningMethod(SigningMethodHS384.Alg(), func() SigningMethod {
|
||||||
|
return SigningMethodHS384
|
||||||
|
})
|
||||||
|
|
||||||
|
// HS512
|
||||||
|
SigningMethodHS512 = &SigningMethodHMAC{"HS512", crypto.SHA512}
|
||||||
|
RegisterSigningMethod(SigningMethodHS512.Alg(), func() SigningMethod {
|
||||||
|
return SigningMethodHS512
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *SigningMethodHMAC) Alg() string {
|
||||||
|
return m.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the signature of HSXXX tokens. Returns nil if the signature is valid.
|
||||||
|
func (m *SigningMethodHMAC) Verify(signingString, signature string, key interface{}) error {
|
||||||
|
// Verify the key is the right type
|
||||||
|
keyBytes, ok := key.([]byte)
|
||||||
|
if !ok {
|
||||||
|
return ErrInvalidKeyType
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode signature, for comparison
|
||||||
|
sig, err := DecodeSegment(signature)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Can we use the specified hashing method?
|
||||||
|
if !m.Hash.Available() {
|
||||||
|
return ErrHashUnavailable
|
||||||
|
}
|
||||||
|
|
||||||
|
// This signing method is symmetric, so we validate the signature
|
||||||
|
// by reproducing the signature from the signing string and key, then
|
||||||
|
// comparing that against the provided signature.
|
||||||
|
hasher := hmac.New(m.Hash.New, keyBytes)
|
||||||
|
hasher.Write([]byte(signingString))
|
||||||
|
if !hmac.Equal(sig, hasher.Sum(nil)) {
|
||||||
|
return ErrSignatureInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
// No validation errors. Signature is good.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implements the Sign method from SigningMethod for this signing method.
|
||||||
|
// Key must be []byte
|
||||||
|
func (m *SigningMethodHMAC) Sign(signingString string, key interface{}) (string, error) {
|
||||||
|
if keyBytes, ok := key.([]byte); ok {
|
||||||
|
if !m.Hash.Available() {
|
||||||
|
return "", ErrHashUnavailable
|
||||||
|
}
|
||||||
|
|
||||||
|
hasher := hmac.New(m.Hash.New, keyBytes)
|
||||||
|
hasher.Write([]byte(signingString))
|
||||||
|
|
||||||
|
return EncodeSegment(hasher.Sum(nil)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", ErrInvalidKeyType
|
||||||
|
}
|
@ -0,0 +1,94 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
// "fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Claims type that uses the map[string]interface{} for JSON decoding
|
||||||
|
// This is the default claims type if you don't supply one
|
||||||
|
type MapClaims map[string]interface{}
|
||||||
|
|
||||||
|
// Compares the aud claim against cmp.
|
||||||
|
// If required is false, this method will return true if the value matches or is unset
|
||||||
|
func (m MapClaims) VerifyAudience(cmp string, req bool) bool {
|
||||||
|
aud, _ := m["aud"].(string)
|
||||||
|
return verifyAud(aud, cmp, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compares the exp claim against cmp.
|
||||||
|
// If required is false, this method will return true if the value matches or is unset
|
||||||
|
func (m MapClaims) VerifyExpiresAt(cmp int64, req bool) bool {
|
||||||
|
switch exp := m["exp"].(type) {
|
||||||
|
case float64:
|
||||||
|
return verifyExp(int64(exp), cmp, req)
|
||||||
|
case json.Number:
|
||||||
|
v, _ := exp.Int64()
|
||||||
|
return verifyExp(v, cmp, req)
|
||||||
|
}
|
||||||
|
return req == false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compares the iat claim against cmp.
|
||||||
|
// If required is false, this method will return true if the value matches or is unset
|
||||||
|
func (m MapClaims) VerifyIssuedAt(cmp int64, req bool) bool {
|
||||||
|
switch iat := m["iat"].(type) {
|
||||||
|
case float64:
|
||||||
|
return verifyIat(int64(iat), cmp, req)
|
||||||
|
case json.Number:
|
||||||
|
v, _ := iat.Int64()
|
||||||
|
return verifyIat(v, cmp, req)
|
||||||
|
}
|
||||||
|
return req == false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compares the iss claim against cmp.
|
||||||
|
// If required is false, this method will return true if the value matches or is unset
|
||||||
|
func (m MapClaims) VerifyIssuer(cmp string, req bool) bool {
|
||||||
|
iss, _ := m["iss"].(string)
|
||||||
|
return verifyIss(iss, cmp, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compares the nbf claim against cmp.
|
||||||
|
// If required is false, this method will return true if the value matches or is unset
|
||||||
|
func (m MapClaims) VerifyNotBefore(cmp int64, req bool) bool {
|
||||||
|
switch nbf := m["nbf"].(type) {
|
||||||
|
case float64:
|
||||||
|
return verifyNbf(int64(nbf), cmp, req)
|
||||||
|
case json.Number:
|
||||||
|
v, _ := nbf.Int64()
|
||||||
|
return verifyNbf(v, cmp, req)
|
||||||
|
}
|
||||||
|
return req == false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validates time based claims "exp, iat, nbf".
|
||||||
|
// There is no accounting for clock skew.
|
||||||
|
// As well, if any of the above claims are not in the token, it will still
|
||||||
|
// be considered a valid claim.
|
||||||
|
func (m MapClaims) Valid() error {
|
||||||
|
vErr := new(ValidationError)
|
||||||
|
now := TimeFunc().Unix()
|
||||||
|
|
||||||
|
if m.VerifyExpiresAt(now, false) == false {
|
||||||
|
vErr.Inner = errors.New("Token is expired")
|
||||||
|
vErr.Errors |= ValidationErrorExpired
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.VerifyIssuedAt(now, false) == false {
|
||||||
|
vErr.Inner = errors.New("Token used before issued")
|
||||||
|
vErr.Errors |= ValidationErrorIssuedAt
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.VerifyNotBefore(now, false) == false {
|
||||||
|
vErr.Inner = errors.New("Token is not valid yet")
|
||||||
|
vErr.Errors |= ValidationErrorNotValidYet
|
||||||
|
}
|
||||||
|
|
||||||
|
if vErr.valid() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return vErr
|
||||||
|
}
|
@ -0,0 +1,52 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
// Implements the none signing method. This is required by the spec
|
||||||
|
// but you probably should never use it.
|
||||||
|
var SigningMethodNone *signingMethodNone
|
||||||
|
|
||||||
|
const UnsafeAllowNoneSignatureType unsafeNoneMagicConstant = "none signing method allowed"
|
||||||
|
|
||||||
|
var NoneSignatureTypeDisallowedError error
|
||||||
|
|
||||||
|
type signingMethodNone struct{}
|
||||||
|
type unsafeNoneMagicConstant string
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
SigningMethodNone = &signingMethodNone{}
|
||||||
|
NoneSignatureTypeDisallowedError = NewValidationError("'none' signature type is not allowed", ValidationErrorSignatureInvalid)
|
||||||
|
|
||||||
|
RegisterSigningMethod(SigningMethodNone.Alg(), func() SigningMethod {
|
||||||
|
return SigningMethodNone
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *signingMethodNone) Alg() string {
|
||||||
|
return "none"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only allow 'none' alg type if UnsafeAllowNoneSignatureType is specified as the key
|
||||||
|
func (m *signingMethodNone) Verify(signingString, signature string, key interface{}) (err error) {
|
||||||
|
// Key must be UnsafeAllowNoneSignatureType to prevent accidentally
|
||||||
|
// accepting 'none' signing method
|
||||||
|
if _, ok := key.(unsafeNoneMagicConstant); !ok {
|
||||||
|
return NoneSignatureTypeDisallowedError
|
||||||
|
}
|
||||||
|
// If signing method is none, signature must be an empty string
|
||||||
|
if signature != "" {
|
||||||
|
return NewValidationError(
|
||||||
|
"'none' signing method with non-empty signature",
|
||||||
|
ValidationErrorSignatureInvalid,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accept 'none' signing method.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only allow 'none' signing if UnsafeAllowNoneSignatureType is specified as the key
|
||||||
|
func (m *signingMethodNone) Sign(signingString string, key interface{}) (string, error) {
|
||||||
|
if _, ok := key.(unsafeNoneMagicConstant); ok {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
return "", NoneSignatureTypeDisallowedError
|
||||||
|
}
|
@ -0,0 +1,148 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Parser parses and validates JWT token strings.
type Parser struct {
	ValidMethods         []string // If populated, only these methods will be considered valid
	UseJSONNumber        bool     // Use JSON Number format in JSON decoder
	SkipClaimsValidation bool     // Skip claims validation during token parsing
}
|
||||||
|
|
||||||
|
// Parse, validate, and return a token.
|
||||||
|
// keyFunc will receive the parsed token and should return the key for validating.
|
||||||
|
// If everything is kosher, err will be nil
|
||||||
|
func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) {
|
||||||
|
return p.ParseWithClaims(tokenString, MapClaims{}, keyFunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc) (*Token, error) {
|
||||||
|
token, parts, err := p.ParseUnverified(tokenString, claims)
|
||||||
|
if err != nil {
|
||||||
|
return token, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify signing method is in the required set
|
||||||
|
if p.ValidMethods != nil {
|
||||||
|
var signingMethodValid = false
|
||||||
|
var alg = token.Method.Alg()
|
||||||
|
for _, m := range p.ValidMethods {
|
||||||
|
if m == alg {
|
||||||
|
signingMethodValid = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !signingMethodValid {
|
||||||
|
// signing method is not in the listed set
|
||||||
|
return token, NewValidationError(fmt.Sprintf("signing method %v is invalid", alg), ValidationErrorSignatureInvalid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lookup key
|
||||||
|
var key interface{}
|
||||||
|
if keyFunc == nil {
|
||||||
|
// keyFunc was not provided. short circuiting validation
|
||||||
|
return token, NewValidationError("no Keyfunc was provided.", ValidationErrorUnverifiable)
|
||||||
|
}
|
||||||
|
if key, err = keyFunc(token); err != nil {
|
||||||
|
// keyFunc returned an error
|
||||||
|
if ve, ok := err.(*ValidationError); ok {
|
||||||
|
return token, ve
|
||||||
|
}
|
||||||
|
return token, &ValidationError{Inner: err, Errors: ValidationErrorUnverifiable}
|
||||||
|
}
|
||||||
|
|
||||||
|
vErr := &ValidationError{}
|
||||||
|
|
||||||
|
// Validate Claims
|
||||||
|
if !p.SkipClaimsValidation {
|
||||||
|
if err := token.Claims.Valid(); err != nil {
|
||||||
|
|
||||||
|
// If the Claims Valid returned an error, check if it is a validation error,
|
||||||
|
// If it was another error type, create a ValidationError with a generic ClaimsInvalid flag set
|
||||||
|
if e, ok := err.(*ValidationError); !ok {
|
||||||
|
vErr = &ValidationError{Inner: err, Errors: ValidationErrorClaimsInvalid}
|
||||||
|
} else {
|
||||||
|
vErr = e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform validation
|
||||||
|
token.Signature = parts[2]
|
||||||
|
if err = token.Method.Verify(strings.Join(parts[0:2], "."), token.Signature, key); err != nil {
|
||||||
|
vErr.Inner = err
|
||||||
|
vErr.Errors |= ValidationErrorSignatureInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
if vErr.valid() {
|
||||||
|
token.Valid = true
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return token, vErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// WARNING: Don't use this method unless you know what you're doing
|
||||||
|
//
|
||||||
|
// This method parses the token but doesn't validate the signature. It's only
|
||||||
|
// ever useful in cases where you know the signature is valid (because it has
|
||||||
|
// been checked previously in the stack) and you want to extract values from
|
||||||
|
// it.
|
||||||
|
func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Token, parts []string, err error) {
|
||||||
|
parts = strings.Split(tokenString, ".")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
return nil, parts, NewValidationError("token contains an invalid number of segments", ValidationErrorMalformed)
|
||||||
|
}
|
||||||
|
|
||||||
|
token = &Token{Raw: tokenString}
|
||||||
|
|
||||||
|
// parse Header
|
||||||
|
var headerBytes []byte
|
||||||
|
if headerBytes, err = DecodeSegment(parts[0]); err != nil {
|
||||||
|
if strings.HasPrefix(strings.ToLower(tokenString), "bearer ") {
|
||||||
|
return token, parts, NewValidationError("tokenstring should not contain 'bearer '", ValidationErrorMalformed)
|
||||||
|
}
|
||||||
|
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
|
||||||
|
}
|
||||||
|
if err = json.Unmarshal(headerBytes, &token.Header); err != nil {
|
||||||
|
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse Claims
|
||||||
|
var claimBytes []byte
|
||||||
|
token.Claims = claims
|
||||||
|
|
||||||
|
if claimBytes, err = DecodeSegment(parts[1]); err != nil {
|
||||||
|
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
|
||||||
|
}
|
||||||
|
dec := json.NewDecoder(bytes.NewBuffer(claimBytes))
|
||||||
|
if p.UseJSONNumber {
|
||||||
|
dec.UseNumber()
|
||||||
|
}
|
||||||
|
// JSON Decode. Special case for map type to avoid weird pointer behavior
|
||||||
|
if c, ok := token.Claims.(MapClaims); ok {
|
||||||
|
err = dec.Decode(&c)
|
||||||
|
} else {
|
||||||
|
err = dec.Decode(&claims)
|
||||||
|
}
|
||||||
|
// Handle decode error
|
||||||
|
if err != nil {
|
||||||
|
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lookup signature method
|
||||||
|
if method, ok := token.Header["alg"].(string); ok {
|
||||||
|
if token.Method = GetSigningMethod(method); token.Method == nil {
|
||||||
|
return token, parts, NewValidationError("signing method (alg) is unavailable.", ValidationErrorUnverifiable)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return token, parts, NewValidationError("signing method (alg) is unspecified.", ValidationErrorUnverifiable)
|
||||||
|
}
|
||||||
|
|
||||||
|
return token, parts, nil
|
||||||
|
}
|
@ -0,0 +1,101 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SigningMethodRSA implements the RSA family of signing methods.
// It expects *rsa.PrivateKey for signing and *rsa.PublicKey for
// validation.
type SigningMethodRSA struct {
	Name string
	Hash crypto.Hash
}

// Specific instances for RS256 and company.
var (
	SigningMethodRS256 *SigningMethodRSA
	SigningMethodRS384 *SigningMethodRSA
	SigningMethodRS512 *SigningMethodRSA
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// RS256
|
||||||
|
SigningMethodRS256 = &SigningMethodRSA{"RS256", crypto.SHA256}
|
||||||
|
RegisterSigningMethod(SigningMethodRS256.Alg(), func() SigningMethod {
|
||||||
|
return SigningMethodRS256
|
||||||
|
})
|
||||||
|
|
||||||
|
// RS384
|
||||||
|
SigningMethodRS384 = &SigningMethodRSA{"RS384", crypto.SHA384}
|
||||||
|
RegisterSigningMethod(SigningMethodRS384.Alg(), func() SigningMethod {
|
||||||
|
return SigningMethodRS384
|
||||||
|
})
|
||||||
|
|
||||||
|
// RS512
|
||||||
|
SigningMethodRS512 = &SigningMethodRSA{"RS512", crypto.SHA512}
|
||||||
|
RegisterSigningMethod(SigningMethodRS512.Alg(), func() SigningMethod {
|
||||||
|
return SigningMethodRS512
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *SigningMethodRSA) Alg() string {
|
||||||
|
return m.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implements the Verify method from SigningMethod
|
||||||
|
// For this signing method, must be an *rsa.PublicKey structure.
|
||||||
|
func (m *SigningMethodRSA) Verify(signingString, signature string, key interface{}) error {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Decode the signature
|
||||||
|
var sig []byte
|
||||||
|
if sig, err = DecodeSegment(signature); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var rsaKey *rsa.PublicKey
|
||||||
|
var ok bool
|
||||||
|
|
||||||
|
if rsaKey, ok = key.(*rsa.PublicKey); !ok {
|
||||||
|
return ErrInvalidKeyType
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create hasher
|
||||||
|
if !m.Hash.Available() {
|
||||||
|
return ErrHashUnavailable
|
||||||
|
}
|
||||||
|
hasher := m.Hash.New()
|
||||||
|
hasher.Write([]byte(signingString))
|
||||||
|
|
||||||
|
// Verify the signature
|
||||||
|
return rsa.VerifyPKCS1v15(rsaKey, m.Hash, hasher.Sum(nil), sig)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implements the Sign method from SigningMethod
|
||||||
|
// For this signing method, must be an *rsa.PrivateKey structure.
|
||||||
|
func (m *SigningMethodRSA) Sign(signingString string, key interface{}) (string, error) {
|
||||||
|
var rsaKey *rsa.PrivateKey
|
||||||
|
var ok bool
|
||||||
|
|
||||||
|
// Validate type of key
|
||||||
|
if rsaKey, ok = key.(*rsa.PrivateKey); !ok {
|
||||||
|
return "", ErrInvalidKey
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the hasher
|
||||||
|
if !m.Hash.Available() {
|
||||||
|
return "", ErrHashUnavailable
|
||||||
|
}
|
||||||
|
|
||||||
|
hasher := m.Hash.New()
|
||||||
|
hasher.Write([]byte(signingString))
|
||||||
|
|
||||||
|
// Sign the string and return the encoded bytes
|
||||||
|
if sigBytes, err := rsa.SignPKCS1v15(rand.Reader, rsaKey, m.Hash, hasher.Sum(nil)); err == nil {
|
||||||
|
return EncodeSegment(sigBytes), nil
|
||||||
|
} else {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,126 @@
|
|||||||
|
// +build go1.4
|
||||||
|
|
||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Implements the RSAPSS family of signing methods signing methods
|
||||||
|
type SigningMethodRSAPSS struct {
|
||||||
|
*SigningMethodRSA
|
||||||
|
Options *rsa.PSSOptions
|
||||||
|
}
|
||||||
|
|
||||||
|
// Specific instances for RS/PS and company
|
||||||
|
var (
|
||||||
|
SigningMethodPS256 *SigningMethodRSAPSS
|
||||||
|
SigningMethodPS384 *SigningMethodRSAPSS
|
||||||
|
SigningMethodPS512 *SigningMethodRSAPSS
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// PS256
|
||||||
|
SigningMethodPS256 = &SigningMethodRSAPSS{
|
||||||
|
&SigningMethodRSA{
|
||||||
|
Name: "PS256",
|
||||||
|
Hash: crypto.SHA256,
|
||||||
|
},
|
||||||
|
&rsa.PSSOptions{
|
||||||
|
SaltLength: rsa.PSSSaltLengthAuto,
|
||||||
|
Hash: crypto.SHA256,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
RegisterSigningMethod(SigningMethodPS256.Alg(), func() SigningMethod {
|
||||||
|
return SigningMethodPS256
|
||||||
|
})
|
||||||
|
|
||||||
|
// PS384
|
||||||
|
SigningMethodPS384 = &SigningMethodRSAPSS{
|
||||||
|
&SigningMethodRSA{
|
||||||
|
Name: "PS384",
|
||||||
|
Hash: crypto.SHA384,
|
||||||
|
},
|
||||||
|
&rsa.PSSOptions{
|
||||||
|
SaltLength: rsa.PSSSaltLengthAuto,
|
||||||
|
Hash: crypto.SHA384,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
RegisterSigningMethod(SigningMethodPS384.Alg(), func() SigningMethod {
|
||||||
|
return SigningMethodPS384
|
||||||
|
})
|
||||||
|
|
||||||
|
// PS512
|
||||||
|
SigningMethodPS512 = &SigningMethodRSAPSS{
|
||||||
|
&SigningMethodRSA{
|
||||||
|
Name: "PS512",
|
||||||
|
Hash: crypto.SHA512,
|
||||||
|
},
|
||||||
|
&rsa.PSSOptions{
|
||||||
|
SaltLength: rsa.PSSSaltLengthAuto,
|
||||||
|
Hash: crypto.SHA512,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
RegisterSigningMethod(SigningMethodPS512.Alg(), func() SigningMethod {
|
||||||
|
return SigningMethodPS512
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implements the Verify method from SigningMethod
|
||||||
|
// For this verify method, key must be an rsa.PublicKey struct
|
||||||
|
func (m *SigningMethodRSAPSS) Verify(signingString, signature string, key interface{}) error {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Decode the signature
|
||||||
|
var sig []byte
|
||||||
|
if sig, err = DecodeSegment(signature); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var rsaKey *rsa.PublicKey
|
||||||
|
switch k := key.(type) {
|
||||||
|
case *rsa.PublicKey:
|
||||||
|
rsaKey = k
|
||||||
|
default:
|
||||||
|
return ErrInvalidKey
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create hasher
|
||||||
|
if !m.Hash.Available() {
|
||||||
|
return ErrHashUnavailable
|
||||||
|
}
|
||||||
|
hasher := m.Hash.New()
|
||||||
|
hasher.Write([]byte(signingString))
|
||||||
|
|
||||||
|
return rsa.VerifyPSS(rsaKey, m.Hash, hasher.Sum(nil), sig, m.Options)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implements the Sign method from SigningMethod
|
||||||
|
// For this signing method, key must be an rsa.PrivateKey struct
|
||||||
|
func (m *SigningMethodRSAPSS) Sign(signingString string, key interface{}) (string, error) {
|
||||||
|
var rsaKey *rsa.PrivateKey
|
||||||
|
|
||||||
|
switch k := key.(type) {
|
||||||
|
case *rsa.PrivateKey:
|
||||||
|
rsaKey = k
|
||||||
|
default:
|
||||||
|
return "", ErrInvalidKeyType
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the hasher
|
||||||
|
if !m.Hash.Available() {
|
||||||
|
return "", ErrHashUnavailable
|
||||||
|
}
|
||||||
|
|
||||||
|
hasher := m.Hash.New()
|
||||||
|
hasher.Write([]byte(signingString))
|
||||||
|
|
||||||
|
// Sign the string and return the encoded bytes
|
||||||
|
if sigBytes, err := rsa.SignPSS(rand.Reader, rsaKey, m.Hash, hasher.Sum(nil), m.Options); err == nil {
|
||||||
|
return EncodeSegment(sigBytes), nil
|
||||||
|
} else {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,101 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Errors returned by the PEM key-parsing helpers below.
var (
	ErrKeyMustBePEMEncoded = errors.New("Invalid Key: Key must be PEM encoded PKCS1 or PKCS8 private key")
	ErrNotRSAPrivateKey    = errors.New("Key is not a valid RSA private key")
	ErrNotRSAPublicKey     = errors.New("Key is not a valid RSA public key")
)
|
||||||
|
|
||||||
|
// Parse PEM encoded PKCS1 or PKCS8 private key
|
||||||
|
func ParseRSAPrivateKeyFromPEM(key []byte) (*rsa.PrivateKey, error) {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Parse PEM block
|
||||||
|
var block *pem.Block
|
||||||
|
if block, _ = pem.Decode(key); block == nil {
|
||||||
|
return nil, ErrKeyMustBePEMEncoded
|
||||||
|
}
|
||||||
|
|
||||||
|
var parsedKey interface{}
|
||||||
|
if parsedKey, err = x509.ParsePKCS1PrivateKey(block.Bytes); err != nil {
|
||||||
|
if parsedKey, err = x509.ParsePKCS8PrivateKey(block.Bytes); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var pkey *rsa.PrivateKey
|
||||||
|
var ok bool
|
||||||
|
if pkey, ok = parsedKey.(*rsa.PrivateKey); !ok {
|
||||||
|
return nil, ErrNotRSAPrivateKey
|
||||||
|
}
|
||||||
|
|
||||||
|
return pkey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse PEM encoded PKCS1 or PKCS8 private key protected with password
|
||||||
|
func ParseRSAPrivateKeyFromPEMWithPassword(key []byte, password string) (*rsa.PrivateKey, error) {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Parse PEM block
|
||||||
|
var block *pem.Block
|
||||||
|
if block, _ = pem.Decode(key); block == nil {
|
||||||
|
return nil, ErrKeyMustBePEMEncoded
|
||||||
|
}
|
||||||
|
|
||||||
|
var parsedKey interface{}
|
||||||
|
|
||||||
|
var blockDecrypted []byte
|
||||||
|
if blockDecrypted, err = x509.DecryptPEMBlock(block, []byte(password)); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if parsedKey, err = x509.ParsePKCS1PrivateKey(blockDecrypted); err != nil {
|
||||||
|
if parsedKey, err = x509.ParsePKCS8PrivateKey(blockDecrypted); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var pkey *rsa.PrivateKey
|
||||||
|
var ok bool
|
||||||
|
if pkey, ok = parsedKey.(*rsa.PrivateKey); !ok {
|
||||||
|
return nil, ErrNotRSAPrivateKey
|
||||||
|
}
|
||||||
|
|
||||||
|
return pkey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse PEM encoded PKCS1 or PKCS8 public key
|
||||||
|
func ParseRSAPublicKeyFromPEM(key []byte) (*rsa.PublicKey, error) {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Parse PEM block
|
||||||
|
var block *pem.Block
|
||||||
|
if block, _ = pem.Decode(key); block == nil {
|
||||||
|
return nil, ErrKeyMustBePEMEncoded
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the key
|
||||||
|
var parsedKey interface{}
|
||||||
|
if parsedKey, err = x509.ParsePKIXPublicKey(block.Bytes); err != nil {
|
||||||
|
if cert, err := x509.ParseCertificate(block.Bytes); err == nil {
|
||||||
|
parsedKey = cert.PublicKey
|
||||||
|
} else {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var pkey *rsa.PublicKey
|
||||||
|
var ok bool
|
||||||
|
if pkey, ok = parsedKey.(*rsa.PublicKey); !ok {
|
||||||
|
return nil, ErrNotRSAPublicKey
|
||||||
|
}
|
||||||
|
|
||||||
|
return pkey, nil
|
||||||
|
}
|
@ -0,0 +1,35 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// signingMethods maps "alg" names to factory functions; it is guarded by
// signingMethodLock so registration and lookup are safe across goroutines.
var signingMethods = map[string]func() SigningMethod{}
var signingMethodLock = new(sync.RWMutex)

// SigningMethod is implemented to add new methods for signing or verifying
// tokens.
type SigningMethod interface {
	// Verify returns nil if the signature is valid for the signing string.
	Verify(signingString, signature string, key interface{}) error
	// Sign returns the encoded signature for the signing string, or an error.
	Sign(signingString string, key interface{}) (string, error)
	// Alg returns the "alg" identifier for this method (example: "HS256").
	Alg() string
}
|
||||||
|
|
||||||
|
// Register the "alg" name and a factory function for signing method.
|
||||||
|
// This is typically done during init() in the method's implementation
|
||||||
|
func RegisterSigningMethod(alg string, f func() SigningMethod) {
|
||||||
|
signingMethodLock.Lock()
|
||||||
|
defer signingMethodLock.Unlock()
|
||||||
|
|
||||||
|
signingMethods[alg] = f
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get a signing method from an "alg" string
|
||||||
|
func GetSigningMethod(alg string) (method SigningMethod) {
|
||||||
|
signingMethodLock.RLock()
|
||||||
|
defer signingMethodLock.RUnlock()
|
||||||
|
|
||||||
|
if methodF, ok := signingMethods[alg]; ok {
|
||||||
|
method = methodF()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
@ -0,0 +1,108 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TimeFunc provides the current time when parsing a token to validate the
// "exp" claim (expiration time). Override it to use another time source —
// useful for testing, or if your server uses a different time zone than
// your tokens.
var TimeFunc = time.Now
|
||||||
|
|
||||||
|
// Parse methods use this callback function to supply
|
||||||
|
// the key for verification. The function receives the parsed,
|
||||||
|
// but unverified Token. This allows you to use properties in the
|
||||||
|
// Header of the token (such as `kid`) to identify which key to use.
|
||||||
|
type Keyfunc func(*Token) (interface{}, error)
|
||||||
|
|
||||||
|
// A JWT Token. Different fields will be used depending on whether you're
|
||||||
|
// creating or parsing/verifying a token.
|
||||||
|
type Token struct {
|
||||||
|
Raw string // The raw token. Populated when you Parse a token
|
||||||
|
Method SigningMethod // The signing method used or to be used
|
||||||
|
Header map[string]interface{} // The first segment of the token
|
||||||
|
Claims Claims // The second segment of the token
|
||||||
|
Signature string // The third segment of the token. Populated when you Parse a token
|
||||||
|
Valid bool // Is the token valid? Populated when you Parse/Verify a token
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new Token. Takes a signing method
|
||||||
|
func New(method SigningMethod) *Token {
|
||||||
|
return NewWithClaims(method, MapClaims{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewWithClaims(method SigningMethod, claims Claims) *Token {
|
||||||
|
return &Token{
|
||||||
|
Header: map[string]interface{}{
|
||||||
|
"typ": "JWT",
|
||||||
|
"alg": method.Alg(),
|
||||||
|
},
|
||||||
|
Claims: claims,
|
||||||
|
Method: method,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the complete, signed token
|
||||||
|
func (t *Token) SignedString(key interface{}) (string, error) {
|
||||||
|
var sig, sstr string
|
||||||
|
var err error
|
||||||
|
if sstr, err = t.SigningString(); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if sig, err = t.Method.Sign(sstr, key); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return strings.Join([]string{sstr, sig}, "."), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate the signing string. This is the
|
||||||
|
// most expensive part of the whole deal. Unless you
|
||||||
|
// need this for something special, just go straight for
|
||||||
|
// the SignedString.
|
||||||
|
func (t *Token) SigningString() (string, error) {
|
||||||
|
var err error
|
||||||
|
parts := make([]string, 2)
|
||||||
|
for i, _ := range parts {
|
||||||
|
var jsonValue []byte
|
||||||
|
if i == 0 {
|
||||||
|
if jsonValue, err = json.Marshal(t.Header); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if jsonValue, err = json.Marshal(t.Claims); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
parts[i] = EncodeSegment(jsonValue)
|
||||||
|
}
|
||||||
|
return strings.Join(parts, "."), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse, validate, and return a token.
|
||||||
|
// keyFunc will receive the parsed token and should return the key for validating.
|
||||||
|
// If everything is kosher, err will be nil
|
||||||
|
func Parse(tokenString string, keyFunc Keyfunc) (*Token, error) {
|
||||||
|
return new(Parser).Parse(tokenString, keyFunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc) (*Token, error) {
|
||||||
|
return new(Parser).ParseWithClaims(tokenString, claims, keyFunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeSegment encodes a JWT segment: base64url with the '=' padding
// stripped, as required by the JWS compact serialization.
func EncodeSegment(seg []byte) string {
	encoded := base64.URLEncoding.EncodeToString(seg)
	return strings.TrimRight(encoded, "=")
}
|
||||||
|
|
||||||
|
// DecodeSegment decodes a JWT-style base64url segment. It tolerates both
// padded and unpadded input by restoring any stripped '=' padding before
// decoding.
func DecodeSegment(seg string) ([]byte, error) {
	if rem := len(seg) % 4; rem != 0 {
		seg += strings.Repeat("=", 4-rem)
	}
	return base64.URLEncoding.DecodeString(seg)
}
|
@ -0,0 +1,5 @@
|
|||||||
|
gocql-fuzz
|
||||||
|
fuzz-corpus
|
||||||
|
fuzz-work
|
||||||
|
gocql.test
|
||||||
|
.idea
|
@ -0,0 +1,49 @@
|
|||||||
|
language: go
|
||||||
|
|
||||||
|
sudo: required
|
||||||
|
dist: trusty
|
||||||
|
|
||||||
|
cache:
|
||||||
|
directories:
|
||||||
|
- $HOME/.ccm/repository
|
||||||
|
- $HOME/.local/lib/python2.7
|
||||||
|
|
||||||
|
matrix:
|
||||||
|
fast_finish: true
|
||||||
|
|
||||||
|
branches:
|
||||||
|
only:
|
||||||
|
- master
|
||||||
|
|
||||||
|
env:
|
||||||
|
global:
|
||||||
|
- GOMAXPROCS=2
|
||||||
|
matrix:
|
||||||
|
- CASS=2.1.21
|
||||||
|
AUTH=true
|
||||||
|
- CASS=2.2.14
|
||||||
|
AUTH=true
|
||||||
|
- CASS=2.2.14
|
||||||
|
AUTH=false
|
||||||
|
- CASS=3.0.18
|
||||||
|
AUTH=false
|
||||||
|
- CASS=3.11.4
|
||||||
|
AUTH=false
|
||||||
|
|
||||||
|
go:
|
||||||
|
- 1.12.x
|
||||||
|
- 1.13.x
|
||||||
|
|
||||||
|
install:
|
||||||
|
- ./install_test_deps.sh $TRAVIS_REPO_SLUG
|
||||||
|
- cd ../..
|
||||||
|
- cd gocql/gocql
|
||||||
|
- go get .
|
||||||
|
|
||||||
|
script:
|
||||||
|
- set -e
|
||||||
|
- PATH=$PATH:$HOME/.local/bin bash integration.sh $CASS $AUTH
|
||||||
|
- go vet .
|
||||||
|
|
||||||
|
notifications:
|
||||||
|
- email: false
|
@ -0,0 +1,114 @@
|
|||||||
|
# This source file refers to The gocql Authors for copyright purposes.
|
||||||
|
|
||||||
|
Christoph Hack <christoph@tux21b.org>
|
||||||
|
Jonathan Rudenberg <jonathan@titanous.com>
|
||||||
|
Thorsten von Eicken <tve@rightscale.com>
|
||||||
|
Matt Robenolt <mattr@disqus.com>
|
||||||
|
Phillip Couto <phillip.couto@stemstudios.com>
|
||||||
|
Niklas Korz <korz.niklask@gmail.com>
|
||||||
|
Nimi Wariboko Jr <nimi@channelmeter.com>
|
||||||
|
Ghais Issa <ghais.issa@gmail.com>
|
||||||
|
Sasha Klizhentas <klizhentas@gmail.com>
|
||||||
|
Konstantin Cherkasov <k.cherkasoff@gmail.com>
|
||||||
|
Ben Hood <0x6e6562@gmail.com>
|
||||||
|
Pete Hopkins <phopkins@gmail.com>
|
||||||
|
Chris Bannister <c.bannister@gmail.com>
|
||||||
|
Maxim Bublis <b@codemonkey.ru>
|
||||||
|
Alex Zorin <git@zor.io>
|
||||||
|
Kasper Middelboe Petersen <me@phant.dk>
|
||||||
|
Harpreet Sawhney <harpreet.sawhney@gmail.com>
|
||||||
|
Charlie Andrews <charlieandrews.cwa@gmail.com>
|
||||||
|
Stanislavs Koikovs <stanislavs.koikovs@gmail.com>
|
||||||
|
Dan Forest <bonjour@dan.tf>
|
||||||
|
Miguel Serrano <miguelvps@gmail.com>
|
||||||
|
Stefan Radomski <gibheer@zero-knowledge.org>
|
||||||
|
Josh Wright <jshwright@gmail.com>
|
||||||
|
Jacob Rhoden <jacob.rhoden@gmail.com>
|
||||||
|
Ben Frye <benfrye@gmail.com>
|
||||||
|
Fred McCann <fred@sharpnoodles.com>
|
||||||
|
Dan Simmons <dan@simmons.io>
|
||||||
|
Muir Manders <muir@retailnext.net>
|
||||||
|
Sankar P <sankar.curiosity@gmail.com>
|
||||||
|
Julien Da Silva <julien.dasilva@gmail.com>
|
||||||
|
Dan Kennedy <daniel@firstcs.co.uk>
|
||||||
|
Nick Dhupia<nick.dhupia@gmail.com>
|
||||||
|
Yasuharu Goto <matope.ono@gmail.com>
|
||||||
|
Jeremy Schlatter <jeremy.schlatter@gmail.com>
|
||||||
|
Matthias Kadenbach <matthias.kadenbach@gmail.com>
|
||||||
|
Dean Elbaz <elbaz.dean@gmail.com>
|
||||||
|
Mike Berman <evencode@gmail.com>
|
||||||
|
Dmitriy Fedorenko <c0va23@gmail.com>
|
||||||
|
Zach Marcantel <zmarcantel@gmail.com>
|
||||||
|
James Maloney <jamessagan@gmail.com>
|
||||||
|
Ashwin Purohit <purohit@gmail.com>
|
||||||
|
Dan Kinder <dkinder.is.me@gmail.com>
|
||||||
|
Oliver Beattie <oliver@obeattie.com>
|
||||||
|
Justin Corpron <jncorpron@gmail.com>
|
||||||
|
Miles Delahunty <miles.delahunty@gmail.com>
|
||||||
|
Zach Badgett <zach.badgett@gmail.com>
|
||||||
|
Maciek Sakrejda <maciek@heroku.com>
|
||||||
|
Jeff Mitchell <jeffrey.mitchell@gmail.com>
|
||||||
|
Baptiste Fontaine <b@ptistefontaine.fr>
|
||||||
|
Matt Heath <matt@mattheath.com>
|
||||||
|
Jamie Cuthill <jamie.cuthill@gmail.com>
|
||||||
|
Adrian Casajus <adriancasajus@gmail.com>
|
||||||
|
John Weldon <johnweldon4@gmail.com>
|
||||||
|
Adrien Bustany <adrien@bustany.org>
|
||||||
|
Andrey Smirnov <smirnov.andrey@gmail.com>
|
||||||
|
Adam Weiner <adamsweiner@gmail.com>
|
||||||
|
Daniel Cannon <daniel@danielcannon.co.uk>
|
||||||
|
Johnny Bergström <johnny@joonix.se>
|
||||||
|
Adriano Orioli <orioli.adriano@gmail.com>
|
||||||
|
Claudiu Raveica <claudiu.raveica@gmail.com>
|
||||||
|
Artem Chernyshev <artem.0xD2@gmail.com>
|
||||||
|
Ference Fu <fym201@msn.com>
|
||||||
|
LOVOO <opensource@lovoo.com>
|
||||||
|
nikandfor <nikandfor@gmail.com>
|
||||||
|
Anthony Woods <awoods@raintank.io>
|
||||||
|
Alexander Inozemtsev <alexander.inozemtsev@gmail.com>
|
||||||
|
Rob McColl <rob@robmccoll.com>; <rmccoll@ionicsecurity.com>
|
||||||
|
Viktor Tönköl <viktor.toenkoel@motionlogic.de>
|
||||||
|
Ian Lozinski <ian.lozinski@gmail.com>
|
||||||
|
Michael Highstead <highstead@gmail.com>
|
||||||
|
Sarah Brown <esbie.is@gmail.com>
|
||||||
|
Caleb Doxsey <caleb@datadoghq.com>
|
||||||
|
Frederic Hemery <frederic.hemery@datadoghq.com>
|
||||||
|
Pekka Enberg <penberg@scylladb.com>
|
||||||
|
Mark M <m.mim95@gmail.com>
|
||||||
|
Bartosz Burclaf <burclaf@gmail.com>
|
||||||
|
Marcus King <marcusking01@gmail.com>
|
||||||
|
Andrew de Andrade <andrew@deandrade.com.br>
|
||||||
|
Robert Nix <robert@nicerobot.org>
|
||||||
|
Nathan Youngman <git@nathany.com>
|
||||||
|
Charles Law <charles.law@gmail.com>; <claw@conduce.com>
|
||||||
|
Nathan Davies <nathanjamesdavies@gmail.com>
|
||||||
|
Bo Blanton <bo.blanton@gmail.com>
|
||||||
|
Vincent Rischmann <me@vrischmann.me>
|
||||||
|
Jesse Claven <jesse.claven@gmail.com>
|
||||||
|
Derrick Wippler <thrawn01@gmail.com>
|
||||||
|
Leigh McCulloch <leigh@leighmcculloch.com>
|
||||||
|
Ron Kuris <swcafe@gmail.com>
|
||||||
|
Raphael Gavache <raphael.gavache@gmail.com>
|
||||||
|
Yasser Abdolmaleki <yasser@yasser.ca>
|
||||||
|
Krishnanand Thommandra <devtkrishna@gmail.com>
|
||||||
|
Blake Atkinson <me@blakeatkinson.com>
|
||||||
|
Dharmendra Parsaila <d4dharmu@gmail.com>
|
||||||
|
Nayef Ghattas <nayef.ghattas@datadoghq.com>
|
||||||
|
Michał Matczuk <mmatczuk@gmail.com>
|
||||||
|
Ben Krebsbach <ben.krebsbach@gmail.com>
|
||||||
|
Vivian Mathews <vivian.mathews.3@gmail.com>
|
||||||
|
Sascha Steinbiss <satta@debian.org>
|
||||||
|
Seth Rosenblum <seth.t.rosenblum@gmail.com>
|
||||||
|
Javier Zunzunegui <javier.zunzunegui.b@gmail.com>
|
||||||
|
Luke Hines <lukehines@protonmail.com>
|
||||||
|
Zhixin Wen <john.wenzhixin@hotmail.com>
|
||||||
|
Chang Liu <changliu.it@gmail.com>
|
||||||
|
Ingo Oeser <nightlyone@gmail.com>
|
||||||
|
Luke Hines <lukehines@protonmail.com>
|
||||||
|
Jacob Greenleaf <jacob@jacobgreenleaf.com>
|
||||||
|
Alex Lourie <alex@instaclustr.com>; <djay.il@gmail.com>
|
||||||
|
Marco Cadetg <cadetg@gmail.com>
|
||||||
|
Karl Matthias <karl@matthias.org>
|
||||||
|
Thomas Meson <zllak@hycik.org>
|
||||||
|
Martin Sucha <martin.sucha@kiwi.com>; <git@mm.ms47.eu>
|
||||||
|
Pavel Buchinchik <p.buchinchik@gmail.com>
|
@ -0,0 +1,78 @@
|
|||||||
|
# Contributing to gocql
|
||||||
|
|
||||||
|
**TL;DR** - this manifesto sets out the bare minimum requirements for submitting a patch to gocql.
|
||||||
|
|
||||||
|
This guide outlines the process of landing patches in gocql and the general approach to maintaining the code base.
|
||||||
|
|
||||||
|
## Background
|
||||||
|
|
||||||
|
The goal of the gocql project is to provide a stable and robust CQL driver for Go. gocql is a community driven project that is coordinated by a small team of core developers.
|
||||||
|
|
||||||
|
## Minimum Requirement Checklist
|
||||||
|
|
||||||
|
The following is a check list of requirements that need to be satisfied in order for us to merge your patch:
|
||||||
|
|
||||||
|
* You should raise a pull request to gocql/gocql on GitHub
|
||||||
|
* The pull request has a title that clearly summarizes the purpose of the patch
|
||||||
|
* The motivation behind the patch is clearly defined in the pull request summary
|
||||||
|
* Your name and email have been added to the `AUTHORS` file (for copyright purposes)
|
||||||
|
* The patch will merge cleanly
|
||||||
|
* The test coverage does not fall below the critical threshold (currently 64%)
|
||||||
|
* The merge commit passes the regression test suite on Travis
|
||||||
|
* `go fmt` has been applied to the submitted code
|
||||||
|
* Functional changes (i.e. new features or changed behavior) are appropriately documented, either as a godoc or in the README (non-functional changes such as bug fixes may not require documentation)
|
||||||
|
|
||||||
|
If there are any requirements that can't be reasonably satisfied, please state this either on the pull request or as part of discussion on the mailing list. Where appropriate, the core team may apply discretion and make an exception to these requirements.
|
||||||
|
|
||||||
|
## Beyond The Checklist
|
||||||
|
|
||||||
|
In addition to stating the hard requirements, there are a bunch of things that we consider when assessing changes to the library. These soft requirements are helpful pointers of how to get a patch landed quicker and with less fuss.
|
||||||
|
|
||||||
|
### General QA Approach
|
||||||
|
|
||||||
|
The gocql team needs to consider the ongoing maintainability of the library at all times. Patches that look like they will introduce maintenance issues for the team will not be accepted.
|
||||||
|
|
||||||
|
Your patch will get merged quicker if you have decent test cases that provide test coverage for the new behavior you wish to introduce.
|
||||||
|
|
||||||
|
Unit tests are good, integration tests are even better. An example of a unit test is `marshal_test.go` - this tests the serialization code in isolation. `cassandra_test.go` is an integration test suite that is executed against every version of Cassandra that gocql supports as part of the CI process on Travis.
|
||||||
|
|
||||||
|
That said, the point of writing tests is to provide a safety net to catch regressions, so there is no need to go overboard with tests. Remember that the more tests you write, the more code we will have to maintain. So there's a balance to strike there.
|
||||||
|
|
||||||
|
### When It's Too Difficult To Automate Testing
|
||||||
|
|
||||||
|
There are legitimate examples of where it is infeasible to write a regression test for a change. Never fear, we will still consider the patch and quite possibly accept the change without a test. The gocql team takes a pragmatic approach to testing. At the end of the day, you could be addressing an issue that is too difficult to reproduce in a test suite, but still occurs in a real production app. In this case, your production app is the test case, and we will have to trust that your change is good.
|
||||||
|
|
||||||
|
Examples of pull requests that have been accepted without tests include:
|
||||||
|
|
||||||
|
* https://github.com/gocql/gocql/pull/181 - this patch would otherwise require a multi-node cluster to be booted as part of the CI build
|
||||||
|
* https://github.com/gocql/gocql/pull/179 - this bug can only be reproduced under heavy load in certain circumstances
|
||||||
|
|
||||||
|
### Sign Off Procedure
|
||||||
|
|
||||||
|
Generally speaking, a pull request can get merged by any one of the core gocql team. If your change is minor, chances are that one team member will just go ahead and merge it there and then. As stated earlier, suitable test coverage will increase the likelihood that a single reviewer will assess and merge your change. If your change has no test coverage, or looks like it may have wider implications for the health and stability of the library, the reviewer may elect to refer the change to another team member to achieve consensus before proceeding. Therefore, the tighter and cleaner your patch is, the quicker it will go through the review process.
|
||||||
|
|
||||||
|
### Supported Features
|
||||||
|
|
||||||
|
gocql is a low level wire driver for Cassandra CQL. By and large, we would like to keep the functional scope of the library as narrow as possible. We think that gocql should be tight and focused, and we will be naturally skeptical of things that could just as easily be implemented in a higher layer. Inevitably you will come across something that could be implemented in a higher layer, save for a minor change to the core API. In this instance, please strike up a conversation with the gocql team. Chances are we will understand what you are trying to achieve and will try to accommodate this in a maintainable way.
|
||||||
|
|
||||||
|
### Longer Term Evolution
|
||||||
|
|
||||||
|
There are some long term plans for gocql that have to be taken into account when assessing changes. That said, gocql is ultimately a community driven project and we don't have a massive development budget, so sometimes the long term view might need to be de-prioritized ahead of short term changes.
|
||||||
|
|
||||||
|
## Officially Supported Server Versions
|
||||||
|
|
||||||
|
Currently, the officially supported versions of the Cassandra server include:
|
||||||
|
|
||||||
|
* 1.2.18
|
||||||
|
* 2.0.9
|
||||||
|
|
||||||
|
Chances are that gocql will work with many other versions. If you would like us to support a particular version of Cassandra, please start a conversation about what version you'd like us to consider. We are more likely to accept a new version if you help out by extending the regression suite to cover the new version to be supported.
|
||||||
|
|
||||||
|
## The Core Dev Team
|
||||||
|
|
||||||
|
The core development team includes:
|
||||||
|
|
||||||
|
* tux21b
|
||||||
|
* phillipCouto
|
||||||
|
* Zariel
|
||||||
|
* 0x6e6562
|
@ -0,0 +1,27 @@
|
|||||||
|
Copyright (c) 2016, The Gocql authors
|
||||||
|
All rights reserved.
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are met:
|
||||||
|
|
||||||
|
* Redistributions of source code must retain the above copyright notice, this
|
||||||
|
list of conditions and the following disclaimer.
|
||||||
|
|
||||||
|
* Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
this list of conditions and the following disclaimer in the documentation
|
||||||
|
and/or other materials provided with the distribution.
|
||||||
|
|
||||||
|
* Neither the name of the copyright holder nor the names of its
|
||||||
|
contributors may be used to endorse or promote products derived from
|
||||||
|
this software without specific prior written permission.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
@ -0,0 +1,215 @@
|
|||||||
|
gocql
|
||||||
|
=====
|
||||||
|
|
||||||
|
[](https://gitter.im/gocql/gocql?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
|
||||||
|
[](https://travis-ci.org/gocql/gocql)
|
||||||
|
[](https://godoc.org/github.com/gocql/gocql)
|
||||||
|
|
||||||
|
Package gocql implements a fast and robust Cassandra client for the
|
||||||
|
Go programming language.
|
||||||
|
|
||||||
|
Project Website: https://gocql.github.io/<br>
|
||||||
|
API documentation: https://godoc.org/github.com/gocql/gocql<br>
|
||||||
|
Discussions: https://groups.google.com/forum/#!forum/gocql
|
||||||
|
|
||||||
|
Supported Versions
|
||||||
|
------------------
|
||||||
|
|
||||||
|
The following matrix shows the versions of Go and Cassandra that are tested with the integration test suite as part of the CI build:
|
||||||
|
|
||||||
|
Go/Cassandra | 2.1.x | 2.2.x | 3.x.x
|
||||||
|
-------------| -------| ------| ---------
|
||||||
|
1.12 | yes | yes | yes
|
||||||
|
1.13 | yes | yes | yes
|
||||||
|
|
||||||
|
Gocql has been tested in production against many different versions of Cassandra. Due to limits in our CI setup we only test against the latest 3 major releases, which coincide with the official support from the Apache project.
|
||||||
|
|
||||||
|
Sunsetting Model
|
||||||
|
----------------
|
||||||
|
|
||||||
|
In general, the gocql team will focus on supporting the current and previous versions of Go. gocql may still work with older versions of Go, but official support for these versions will have been sunset.
|
||||||
|
|
||||||
|
Installation
|
||||||
|
------------
|
||||||
|
|
||||||
|
go get github.com/gocql/gocql
|
||||||
|
|
||||||
|
|
||||||
|
Features
|
||||||
|
--------
|
||||||
|
|
||||||
|
* Modern Cassandra client using the native transport
|
||||||
|
* Automatic type conversions between Cassandra and Go
|
||||||
|
* Support for all common types including sets, lists and maps
|
||||||
|
* Custom types can implement a `Marshaler` and `Unmarshaler` interface
|
||||||
|
* Strict type conversions without any loss of precision
|
||||||
|
* Built-In support for UUIDs (version 1 and 4)
|
||||||
|
* Support for logged, unlogged and counter batches
|
||||||
|
* Cluster management
|
||||||
|
* Automatic reconnect on connection failures with exponential falloff
|
||||||
|
* Round robin distribution of queries to different hosts
|
||||||
|
* Round robin distribution of queries to different connections on a host
|
||||||
|
* Each connection can execute up to n concurrent queries (whereby n is the limit set by the protocol version the client chooses to use)
|
||||||
|
* Optional automatic discovery of nodes
|
||||||
|
* Policy based connection pool with token aware and round-robin policy implementations
|
||||||
|
* Support for password authentication
|
||||||
|
* Iteration over paged results with configurable page size
|
||||||
|
* Support for TLS/SSL
|
||||||
|
* Optional frame compression (using snappy)
|
||||||
|
* Automatic query preparation
|
||||||
|
* Support for query tracing
|
||||||
|
* Support for Cassandra 2.1+ [binary protocol version 3](https://github.com/apache/cassandra/blob/trunk/doc/native_protocol_v3.spec)
|
||||||
|
* Support for up to 32768 streams
|
||||||
|
* Support for tuple types
|
||||||
|
* Support for client side timestamps by default
|
||||||
|
* Support for UDTs via a custom marshaller or struct tags
|
||||||
|
* Support for Cassandra 3.0+ [binary protocol version 4](https://github.com/apache/cassandra/blob/trunk/doc/native_protocol_v4.spec)
|
||||||
|
* An API to access the schema metadata of a given keyspace
|
||||||
|
|
||||||
|
Performance
|
||||||
|
-----------
|
||||||
|
While the driver strives to be highly performant, there are cases where it is difficult to test and verify. The driver is built
|
||||||
|
with maintainability and code readability in mind first and then performance and features, as such every now and then performance
|
||||||
|
may degrade; if this occurs please report an issue and it will be looked at and remedied. The only time the driver copies data from
|
||||||
|
its read buffer is when it Unmarshal's data into supplied types.
|
||||||
|
|
||||||
|
Some tips for getting more performance from the driver:
|
||||||
|
* Use the TokenAware policy
|
||||||
|
* Use many goroutines when doing inserts, the driver is asynchronous but provides a synchronous API, it can execute many queries concurrently
|
||||||
|
* Tune query page size
|
||||||
|
* Reading data from the network to unmarshal will incur a large amount of allocations, this can adversely affect the garbage collector, tune `GOGC`
|
||||||
|
* Close iterators after use to recycle byte buffers
|
||||||
|
|
||||||
|
Important Default Keyspace Changes
|
||||||
|
----------------------------------
|
||||||
|
gocql no longer supports executing "use <keyspace>" statements to simplify the library. The user still has the
|
||||||
|
ability to define the default keyspace for connections but now the keyspace can only be defined before a
|
||||||
|
session is created. Queries can still access keyspaces by indicating the keyspace in the query:
|
||||||
|
```sql
|
||||||
|
SELECT * FROM example2.table;
|
||||||
|
```
|
||||||
|
|
||||||
|
Example of correct usage:
|
||||||
|
```go
|
||||||
|
cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
|
||||||
|
cluster.Keyspace = "example"
|
||||||
|
...
|
||||||
|
session, err := cluster.CreateSession()
|
||||||
|
|
||||||
|
```
|
||||||
|
Example of incorrect usage:
|
||||||
|
```go
|
||||||
|
cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
|
||||||
|
cluster.Keyspace = "example"
|
||||||
|
...
|
||||||
|
session, err := cluster.CreateSession()
|
||||||
|
|
||||||
|
if err = session.Query("use example2").Exec(); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
This will result in an err being returned from the session.Query line as the user is trying to execute a "use"
|
||||||
|
statement.
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
|
||||||
|
```go
|
||||||
|
/* Before you execute the program, Launch `cqlsh` and execute:
|
||||||
|
create keyspace example with replication = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };
|
||||||
|
create table example.tweet(timeline text, id UUID, text text, PRIMARY KEY(id));
|
||||||
|
create index on example.tweet(timeline);
|
||||||
|
*/
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/gocql/gocql"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// connect to the cluster
|
||||||
|
cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
|
||||||
|
cluster.Keyspace = "example"
|
||||||
|
cluster.Consistency = gocql.Quorum
|
||||||
|
session, _ := cluster.CreateSession()
|
||||||
|
defer session.Close()
|
||||||
|
|
||||||
|
// insert a tweet
|
||||||
|
if err := session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`,
|
||||||
|
"me", gocql.TimeUUID(), "hello world").Exec(); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var id gocql.UUID
|
||||||
|
var text string
|
||||||
|
|
||||||
|
/* Search for a specific set of records whose 'timeline' column matches
|
||||||
|
* the value 'me'. The secondary index that we created earlier will be
|
||||||
|
* used for optimizing the search */
|
||||||
|
if err := session.Query(`SELECT id, text FROM tweet WHERE timeline = ? LIMIT 1`,
|
||||||
|
"me").Consistency(gocql.One).Scan(&id, &text); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
fmt.Println("Tweet:", id, text)
|
||||||
|
|
||||||
|
// list all tweets
|
||||||
|
iter := session.Query(`SELECT id, text FROM tweet WHERE timeline = ?`, "me").Iter()
|
||||||
|
for iter.Scan(&id, &text) {
|
||||||
|
fmt.Println("Tweet:", id, text)
|
||||||
|
}
|
||||||
|
if err := iter.Close(); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Data Binding
|
||||||
|
------------
|
||||||
|
|
||||||
|
There are various ways to bind application level data structures to CQL statements:
|
||||||
|
|
||||||
|
* You can write the data binding by hand, as outlined in the Tweet example. This provides you with the greatest flexibility, but it does mean that you need to keep your application code in sync with your Cassandra schema.
|
||||||
|
* You can dynamically marshal an entire query result into an `[]map[string]interface{}` using the `SliceMap()` API. This returns a slice of row maps keyed by CQL column names. This method requires no special interaction with the gocql API, but it does require your application to be able to deal with a key value view of your data.
|
||||||
|
* As a refinement on the `SliceMap()` API you can also call `MapScan()` which returns `map[string]interface{}` instances in a row by row fashion.
|
||||||
|
* The `Bind()` API provides a client app with a low level mechanism to introspect query meta data and extract appropriate field values from application level data structures.
|
||||||
|
* The [gocqlx](https://github.com/scylladb/gocqlx) package is an idiomatic extension to gocql that provides usability features. With gocqlx you can bind the query parameters from maps and structs, use named query parameters (:identifier) and scan the query results into structs and slices. It comes with a fluent and flexible CQL query builder that supports full CQL spec, including BATCH statements and custom functions.
|
||||||
|
* Building on top of the gocql driver, [cqlr](https://github.com/relops/cqlr) adds the ability to auto-bind a CQL iterator to a struct or to bind a struct to an INSERT statement.
|
||||||
|
* Another external project that layers on top of gocql is [cqlc](http://relops.com/cqlc) which generates gocql compliant code from your Cassandra schema so that you can write type safe CQL statements in Go with a natural query syntax.
|
||||||
|
* [gocassa](https://github.com/hailocab/gocassa) is an external project that layers on top of gocql to provide convenient query building and data binding.
|
||||||
|
* [gocqltable](https://github.com/kristoiv/gocqltable) provides an ORM-style convenience layer to make CRUD operations with gocql easier.
|
||||||
|
|
||||||
|
Ecosystem
|
||||||
|
---------
|
||||||
|
|
||||||
|
The following community maintained tools are known to integrate with gocql:
|
||||||
|
|
||||||
|
* [gocqlx](https://github.com/scylladb/gocqlx) is a gocql extension that automates data binding, adds named queries support, provides flexible query builders and plays well with gocql.
|
||||||
|
* [journey](https://github.com/db-journey/journey) is a migration tool with Cassandra support.
|
||||||
|
* [negronicql](https://github.com/mikebthun/negronicql) is gocql middleware for Negroni.
|
||||||
|
* [cqlr](https://github.com/relops/cqlr) adds the ability to auto-bind a CQL iterator to a struct or to bind a struct to an INSERT statement.
|
||||||
|
* [cqlc](http://relops.com/cqlc) generates gocql compliant code from your Cassandra schema so that you can write type safe CQL statements in Go with a natural query syntax.
|
||||||
|
* [gocassa](https://github.com/hailocab/gocassa) provides query building, adds data binding, and provides easy-to-use "recipe" tables for common query use-cases.
|
||||||
|
* [gocqltable](https://github.com/kristoiv/gocqltable) is a wrapper around gocql that aims to simplify common operations.
|
||||||
|
* [gockle](https://github.com/willfaught/gockle) provides simple, mockable interfaces that wrap gocql types
|
||||||
|
* [scylladb](https://github.com/scylladb/scylla) is a fast Apache Cassandra-compatible NoSQL database
|
||||||
|
* [go-cql-driver](https://github.com/MichaelS11/go-cql-driver) is an CQL driver conforming to the built-in database/sql interface. It is good for simple use cases where the database/sql interface is wanted. The CQL driver is a wrapper around this project.
|
||||||
|
|
||||||
|
Other Projects
|
||||||
|
--------------
|
||||||
|
|
||||||
|
* [gocqldriver](https://github.com/tux21b/gocqldriver) is the predecessor of gocql based on Go's `database/sql` package. This project isn't maintained anymore, because Cassandra wasn't a good fit for the traditional `database/sql` API. Use this package instead.
|
||||||
|
|
||||||
|
SEO
|
||||||
|
---
|
||||||
|
|
||||||
|
For some reason, when you Google `golang cassandra`, this project doesn't feature very highly in the result list. But if you Google `go cassandra`, then we're a bit higher up the list. So this is a note to try to convince Google that golang is an alias for Go.
|
||||||
|
|
||||||
|
License
|
||||||
|
-------
|
||||||
|
|
||||||
|
> Copyright (c) 2012-2016 The gocql Authors. All rights reserved.
|
||||||
|
> Use of this source code is governed by a BSD-style
|
||||||
|
> license that can be found in the LICENSE file.
|
@ -0,0 +1,26 @@
|
|||||||
|
package gocql
|
||||||
|
|
||||||
|
import "net"
|
||||||
|
|
||||||
|
// AddressTranslator provides a way to translate node addresses (and ports) that are
|
||||||
|
// discovered or received as a node event. This can be useful in an ec2 environment,
|
||||||
|
// for instance, to translate public IPs to private IPs.
|
||||||
|
type AddressTranslator interface {
|
||||||
|
// Translate will translate the provided address and/or port to another
|
||||||
|
// address and/or port. If no translation is possible, Translate will return the
|
||||||
|
// address and port provided to it.
|
||||||
|
Translate(addr net.IP, port int) (net.IP, int)
|
||||||
|
}
|
||||||
|
|
||||||
|
type AddressTranslatorFunc func(addr net.IP, port int) (net.IP, int)
|
||||||
|
|
||||||
|
func (fn AddressTranslatorFunc) Translate(addr net.IP, port int) (net.IP, int) {
|
||||||
|
return fn(addr, port)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IdentityTranslator will do nothing but return what it was provided. It is essentially a no-op.
|
||||||
|
func IdentityTranslator() AddressTranslator {
|
||||||
|
return AddressTranslatorFunc(func(addr net.IP, port int) (net.IP, int) {
|
||||||
|
return addr, port
|
||||||
|
})
|
||||||
|
}
|
@ -0,0 +1,211 @@
|
|||||||
|
// Copyright (c) 2012 The gocql Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package gocql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PoolConfig configures the connection pool used by the driver, it defaults to
// using a round-robin host selection policy and a round-robin connection selection
// policy for each host.
type PoolConfig struct {
	// HostSelectionPolicy sets the policy for selecting which host to use for a
	// given query (default: RoundRobinHostPolicy())
	HostSelectionPolicy HostSelectionPolicy
}

// buildPool constructs the session-wide connection pool for this configuration.
// NOTE(review): the PoolConfig receiver is not consulted here — the pool is
// built solely from the session; confirm against newPolicyConnPool.
func (p PoolConfig) buildPool(session *Session) *policyConnPool {
	return newPolicyConnPool(session)
}
|
||||||
|
|
||||||
|
// ClusterConfig is a struct to configure the default cluster implementation
// of gocql. It has a variety of attributes that can be used to modify the
// behavior to fit the most common use cases. Applications that require a
// different setup must implement their own cluster.
type ClusterConfig struct {
	// Hosts holds the addresses for the initial connections. It is recommended to use the value set in
	// the Cassandra config for broadcast_address or listen_address, an IP address not
	// a domain name. This is because events from Cassandra will use the configured IP
	// address, which is used to index connected hosts. If the domain name specified
	// resolves to more than 1 IP address then the driver may connect multiple times to
	// the same host, and will not mark the node being down or up from events.
	Hosts      []string
	CQLVersion string // CQL version (default: 3.0.0)

	// ProtoVersion sets the version of the native protocol to use, this will
	// enable features in the driver for specific protocol versions, generally this
	// should be set to a known version (2,3,4) for the cluster being connected to.
	//
	// If it is 0 or unset (the default) then the driver will attempt to discover the
	// highest supported protocol for the cluster. In clusters with nodes of different
	// versions the protocol selected is not defined (ie, it can be any of the supported in the cluster)
	ProtoVersion       int
	Timeout            time.Duration                            // connection timeout (default: 600ms)
	ConnectTimeout     time.Duration                            // initial connection timeout, used during initial dial to server (default: 600ms)
	Port               int                                      // port (default: 9042)
	Keyspace           string                                   // initial keyspace (optional)
	NumConns           int                                      // number of connections per host (default: 2)
	Consistency        Consistency                              // default consistency level (default: Quorum)
	Compressor         Compressor                               // compression algorithm (default: nil)
	Authenticator      Authenticator                            // authenticator (default: nil)
	AuthProvider       func(h *HostInfo) (Authenticator, error) // an authenticator factory. Can be used to create alternative authenticators (default: nil)
	RetryPolicy        RetryPolicy                              // Default retry policy to use for queries (default: 0)
	ConvictionPolicy   ConvictionPolicy                         // Decide whether to mark host as down based on the error and host info (default: SimpleConvictionPolicy)
	ReconnectionPolicy ReconnectionPolicy                       // Default reconnection policy to use for reconnecting before trying to mark host as down (default: see below)
	SocketKeepalive    time.Duration                            // The keepalive period to use, enabled if > 0 (default: 0)
	MaxPreparedStmts   int                                      // Sets the maximum cache size for prepared statements globally for gocql (default: 1000)
	MaxRoutingKeyInfo  int                                      // Sets the maximum cache size for query info about statements for each session (default: 1000)
	PageSize           int                                      // Default page size to use for created sessions (default: 5000)
	SerialConsistency  SerialConsistency                        // Sets the consistency for the serial part of queries, values can be either SERIAL or LOCAL_SERIAL (default: unset)
	SslOpts            *SslOptions                              // TLS options; when non-nil, connections use TLS (see setupTLSConfig)
	DefaultTimestamp   bool                                     // Sends a client side timestamp for all requests which overrides the timestamp at which it arrives at the server. (default: true, only enabled for protocol 3 and above)

	// PoolConfig configures the underlying connection pool, allowing the
	// configuration of host selection and connection selection policies.
	PoolConfig PoolConfig

	// If not zero, gocql attempt to reconnect known DOWN nodes in every ReconnectInterval.
	ReconnectInterval time.Duration

	// The maximum amount of time to wait for schema agreement in a cluster after
	// receiving a schema change frame. (default: 60s)
	MaxWaitSchemaAgreement time.Duration

	// HostFilter will filter all incoming events for host, any which don't pass
	// the filter will be ignored. If set will take precedence over any options set
	// via Discovery
	HostFilter HostFilter

	// AddressTranslator will translate addresses found on peer discovery and/or
	// node change events.
	AddressTranslator AddressTranslator

	// If IgnorePeerAddr is true and the address in system.peers does not match
	// the supplied host by either initial hosts or discovered via events then the
	// host will be replaced with the supplied address.
	//
	// For example if an event comes in with host=10.0.0.1 but when looking up that
	// address in system.local or system.peers returns 127.0.0.1, the peer will be
	// set to 10.0.0.1 which is what will be used to connect to.
	IgnorePeerAddr bool

	// If DisableInitialHostLookup then the driver will not attempt to get host info
	// from the system.peers table, this will mean that the driver will connect to
	// hosts supplied and will not attempt to lookup the hosts information, this will
	// mean that data_centre, rack and token information will not be available and as
	// such host filtering and token aware query routing will not be available.
	DisableInitialHostLookup bool

	// Events configures which server-pushed events the driver will register for.
	Events struct {
		// disable registering for status events (node up/down)
		DisableNodeStatusEvents bool
		// disable registering for topology events (node added/removed/moved)
		DisableTopologyEvents bool
		// disable registering for schema events (keyspace/table/function removed/created/updated)
		DisableSchemaEvents bool
	}

	// DisableSkipMetadata will override the internal result metadata cache so that the driver does not
	// send skip_metadata for queries, this means that the result will always contain
	// the metadata to parse the rows and will not reuse the metadata from the prepared
	// statement.
	//
	// See https://issues.apache.org/jira/browse/CASSANDRA-10786
	DisableSkipMetadata bool

	// QueryObserver will set the provided query observer on all queries created from this session.
	// Use it to collect metrics / stats from queries by providing an implementation of QueryObserver.
	QueryObserver QueryObserver

	// BatchObserver will set the provided batch observer on all queries created from this session.
	// Use it to collect metrics / stats from batch queries by providing an implementation of BatchObserver.
	BatchObserver BatchObserver

	// ConnectObserver will set the provided connect observer on all queries
	// created from this session.
	ConnectObserver ConnectObserver

	// FrameHeaderObserver will set the provided frame header observer on all frames' headers created from this session.
	// Use it to collect metrics / stats from frames by providing an implementation of FrameHeaderObserver.
	FrameHeaderObserver FrameHeaderObserver

	// Default idempotence for queries
	DefaultIdempotence bool

	// The time to wait for frames before flushing the frames connection to Cassandra.
	// Can help reduce syscall overhead by making less calls to write. Set to 0 to
	// disable.
	//
	// (default: 200 microseconds)
	WriteCoalesceWaitTime time.Duration

	// internal config for testing
	disableControlConn bool
}
|
||||||
|
|
||||||
|
// NewCluster generates a new config for the default cluster implementation.
//
// The supplied hosts are used to initially connect to the cluster then the rest of
// the ring will be automatically discovered. It is recommended to use the value set in
// the Cassandra config for broadcast_address or listen_address, an IP address not
// a domain name. This is because events from Cassandra will use the configured IP
// address, which is used to index connected hosts. If the domain name specified
// resolves to more than 1 IP address then the driver may connect multiple times to
// the same host, and will not mark the node being down or up from events.
func NewCluster(hosts ...string) *ClusterConfig {
	// Defaults below mirror the documented defaults on the ClusterConfig fields.
	cfg := &ClusterConfig{
		Hosts:                  hosts,
		CQLVersion:             "3.0.0",
		Timeout:                600 * time.Millisecond,
		ConnectTimeout:         600 * time.Millisecond,
		Port:                   9042,
		NumConns:               2,
		Consistency:            Quorum,
		MaxPreparedStmts:       defaultMaxPreparedStmts,
		MaxRoutingKeyInfo:      1000,
		PageSize:               5000,
		DefaultTimestamp:       true,
		MaxWaitSchemaAgreement: 60 * time.Second,
		ReconnectInterval:      60 * time.Second,
		ConvictionPolicy:       &SimpleConvictionPolicy{},
		ReconnectionPolicy:     &ConstantReconnectionPolicy{MaxRetries: 3, Interval: 1 * time.Second},
		WriteCoalesceWaitTime:  200 * time.Microsecond,
	}
	return cfg
}
|
||||||
|
|
||||||
|
// CreateSession initializes the cluster based on this config and returns a
// session object that can be used to interact with the database.
// The config is copied by value, so later mutations of cfg do not affect
// the created session.
func (cfg *ClusterConfig) CreateSession() (*Session, error) {
	return NewSession(*cfg)
}
|
||||||
|
|
||||||
|
// translateAddressPort is a helper method that will use the configured
// AddressTranslator, if defined, to translate the given address and port into
// a possibly new address and port. If no AddressTranslator is configured, or
// the address is empty, the given address and port are returned unchanged.
func (cfg *ClusterConfig) translateAddressPort(addr net.IP, port int) (net.IP, int) {
	if cfg.AddressTranslator == nil || len(addr) == 0 {
		return addr, port
	}
	newAddr, newPort := cfg.AddressTranslator.Translate(addr, port)
	if gocqlDebug {
		// Trace each translation only when debug logging is enabled.
		Logger.Printf("gocql: translating address '%v:%d' to '%v:%d'", addr, port, newAddr, newPort)
	}
	return newAddr, newPort
}
|
||||||
|
|
||||||
|
func (cfg *ClusterConfig) filterHost(host *HostInfo) bool {
|
||||||
|
return !(cfg.HostFilter == nil || cfg.HostFilter.Accept(host))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sentinel errors returned while creating a session.
var (
	ErrNoHosts              = errors.New("no hosts provided")
	ErrNoConnectionsStarted = errors.New("no connections were made when creating the session")
	ErrHostQueryFailed      = errors.New("unable to populate Hosts")
)
|
@ -0,0 +1,28 @@
|
|||||||
|
package gocql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/golang/snappy"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Compressor is implemented by frame compression codecs used for the
// native-protocol payloads exchanged with Cassandra.
type Compressor interface {
	// Name returns the codec name as negotiated with the server (e.g. "snappy").
	Name() string
	// Encode compresses data and returns the compressed payload.
	Encode(data []byte) ([]byte, error)
	// Decode decompresses data and returns the original payload.
	Decode(data []byte) ([]byte, error)
}

// SnappyCompressor implements the Compressor interface and can be used to
// compress incoming and outgoing frames. The snappy compression algorithm
// aims for very high speeds and reasonable compression.
type SnappyCompressor struct{}

// Name reports the protocol-level codec name.
func (s SnappyCompressor) Name() string {
	return "snappy"
}

// Encode snappy-compresses data into a freshly allocated buffer.
func (s SnappyCompressor) Encode(data []byte) ([]byte, error) {
	return snappy.Encode(nil, data), nil
}

// Decode snappy-decompresses data into a freshly allocated buffer.
func (s SnappyCompressor) Decode(data []byte) ([]byte, error) {
	return snappy.Decode(nil, data)
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,581 @@
|
|||||||
|
// Copyright (c) 2012 The gocql Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package gocql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetHosts is the interface to implement to receive updated host information.
type SetHosts interface {
	SetHosts(hosts []*HostInfo)
}

// SetPartitioner is the interface to implement to receive the partitioner value.
type SetPartitioner interface {
	SetPartitioner(partitioner string)
}
|
||||||
|
|
||||||
|
func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
|
||||||
|
if sslOpts.Config == nil {
|
||||||
|
sslOpts.Config = &tls.Config{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ca cert is optional
|
||||||
|
if sslOpts.CaPath != "" {
|
||||||
|
if sslOpts.RootCAs == nil {
|
||||||
|
sslOpts.RootCAs = x509.NewCertPool()
|
||||||
|
}
|
||||||
|
|
||||||
|
pem, err := ioutil.ReadFile(sslOpts.CaPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("connectionpool: unable to open CA certs: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !sslOpts.RootCAs.AppendCertsFromPEM(pem) {
|
||||||
|
return nil, errors.New("connectionpool: failed parsing or CA certs")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if sslOpts.CertPath != "" || sslOpts.KeyPath != "" {
|
||||||
|
mycert, err := tls.LoadX509KeyPair(sslOpts.CertPath, sslOpts.KeyPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("connectionpool: unable to load X509 key pair: %v", err)
|
||||||
|
}
|
||||||
|
sslOpts.Certificates = append(sslOpts.Certificates, mycert)
|
||||||
|
}
|
||||||
|
|
||||||
|
sslOpts.InsecureSkipVerify = !sslOpts.EnableHostVerification
|
||||||
|
|
||||||
|
// return clone to avoid race
|
||||||
|
return sslOpts.Config.Clone(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// policyConnPool is the session-wide connection pool: it keeps one
// hostConnPool per known up host, keyed by the host's connect address.
type policyConnPool struct {
	session *Session

	// Connection parameters copied from the session config.
	port     int
	numConns int
	keyspace string

	// mu guards hostConnPools.
	mu            sync.RWMutex
	hostConnPools map[string]*hostConnPool

	// endpoints is a copy of the initially configured host list.
	endpoints []string
}
|
||||||
|
|
||||||
|
// connConfig derives the per-connection configuration (protocol version,
// timeouts, auth, compression and TLS) from a ClusterConfig.
func connConfig(cfg *ClusterConfig) (*ConnConfig, error) {
	var (
		err       error
		tlsConfig *tls.Config
	)

	// TODO(zariel): move tls config setup into session init.
	if cfg.SslOpts != nil {
		tlsConfig, err = setupTLSConfig(cfg.SslOpts)
		if err != nil {
			return nil, err
		}
	}

	return &ConnConfig{
		ProtoVersion:    cfg.ProtoVersion,
		CQLVersion:      cfg.CQLVersion,
		Timeout:         cfg.Timeout,
		ConnectTimeout:  cfg.ConnectTimeout,
		Compressor:      cfg.Compressor,
		Authenticator:   cfg.Authenticator,
		AuthProvider:    cfg.AuthProvider,
		Keepalive:       cfg.SocketKeepalive,
		tlsConfig:       tlsConfig,
		disableCoalesce: tlsConfig != nil, // write coalescing doesn't work with framing on top of TCP like in TLS.
	}, nil
}
|
||||||
|
|
||||||
|
func newPolicyConnPool(session *Session) *policyConnPool {
|
||||||
|
// create the pool
|
||||||
|
pool := &policyConnPool{
|
||||||
|
session: session,
|
||||||
|
port: session.cfg.Port,
|
||||||
|
numConns: session.cfg.NumConns,
|
||||||
|
keyspace: session.cfg.Keyspace,
|
||||||
|
hostConnPools: map[string]*hostConnPool{},
|
||||||
|
}
|
||||||
|
|
||||||
|
pool.endpoints = make([]string, len(session.cfg.Hosts))
|
||||||
|
copy(pool.endpoints, session.cfg.Hosts)
|
||||||
|
|
||||||
|
return pool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *policyConnPool) SetHosts(hosts []*HostInfo) {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
|
||||||
|
toRemove := make(map[string]struct{})
|
||||||
|
for addr := range p.hostConnPools {
|
||||||
|
toRemove[addr] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
pools := make(chan *hostConnPool)
|
||||||
|
createCount := 0
|
||||||
|
for _, host := range hosts {
|
||||||
|
if !host.IsUp() {
|
||||||
|
// don't create a connection pool for a down host
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ip := host.ConnectAddress().String()
|
||||||
|
if _, exists := p.hostConnPools[ip]; exists {
|
||||||
|
// still have this host, so don't remove it
|
||||||
|
delete(toRemove, ip)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
createCount++
|
||||||
|
go func(host *HostInfo) {
|
||||||
|
// create a connection pool for the host
|
||||||
|
pools <- newHostConnPool(
|
||||||
|
p.session,
|
||||||
|
host,
|
||||||
|
p.port,
|
||||||
|
p.numConns,
|
||||||
|
p.keyspace,
|
||||||
|
)
|
||||||
|
}(host)
|
||||||
|
}
|
||||||
|
|
||||||
|
// add created pools
|
||||||
|
for createCount > 0 {
|
||||||
|
pool := <-pools
|
||||||
|
createCount--
|
||||||
|
if pool.Size() > 0 {
|
||||||
|
// add pool only if there a connections available
|
||||||
|
p.hostConnPools[string(pool.host.ConnectAddress())] = pool
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for addr := range toRemove {
|
||||||
|
pool := p.hostConnPools[addr]
|
||||||
|
delete(p.hostConnPools, addr)
|
||||||
|
go pool.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *policyConnPool) Size() int {
|
||||||
|
p.mu.RLock()
|
||||||
|
count := 0
|
||||||
|
for _, pool := range p.hostConnPools {
|
||||||
|
count += pool.Size()
|
||||||
|
}
|
||||||
|
p.mu.RUnlock()
|
||||||
|
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *policyConnPool) getPool(host *HostInfo) (pool *hostConnPool, ok bool) {
|
||||||
|
ip := host.ConnectAddress().String()
|
||||||
|
p.mu.RLock()
|
||||||
|
pool, ok = p.hostConnPools[ip]
|
||||||
|
p.mu.RUnlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *policyConnPool) Close() {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
|
||||||
|
// close the pools
|
||||||
|
for addr, pool := range p.hostConnPools {
|
||||||
|
delete(p.hostConnPools, addr)
|
||||||
|
pool.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// addHost ensures a connection pool exists for the given host, creating one
// under the lock if needed, then (re)fills it after releasing the lock
// because fill may dial and block.
func (p *policyConnPool) addHost(host *HostInfo) {
	ip := host.ConnectAddress().String()
	p.mu.Lock()
	pool, ok := p.hostConnPools[ip]
	if !ok {
		pool = newHostConnPool(
			p.session,
			host,
			host.Port(), // TODO: if port == 0 use pool.port?
			p.numConns,
			p.keyspace,
		)

		p.hostConnPools[ip] = pool
	}
	p.mu.Unlock()

	pool.fill()
}
|
||||||
|
|
||||||
|
func (p *policyConnPool) removeHost(ip net.IP) {
|
||||||
|
k := ip.String()
|
||||||
|
p.mu.Lock()
|
||||||
|
pool, ok := p.hostConnPools[k]
|
||||||
|
if !ok {
|
||||||
|
p.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(p.hostConnPools, k)
|
||||||
|
p.mu.Unlock()
|
||||||
|
|
||||||
|
go pool.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// hostUp reacts to a host coming up by (re)creating its connection pool.
func (p *policyConnPool) hostUp(host *HostInfo) {
	// TODO(zariel): have a set of up hosts and down hosts, we can internally
	// detect down hosts, then try to reconnect to them.
	p.addHost(host)
}
|
||||||
|
|
||||||
|
// hostDown reacts to a host going down by tearing down its connection pool.
func (p *policyConnPool) hostDown(ip net.IP) {
	// TODO(zariel): mark host as down so we can try to connect to it later, for
	// now just treat it has removed.
	p.removeHost(ip)
}
|
||||||
|
|
||||||
|
// hostConnPool is a connection pool for a single host.
// Connection selection is based on a provided ConnSelectionPolicy
type hostConnPool struct {
	session  *Session
	host     *HostInfo
	port     int
	addr     string // "ip:port" text form of the host address
	size     int    // target number of connections to maintain
	keyspace string
	// protection for conns, closed, filling
	mu      sync.RWMutex
	conns   []*Conn
	closed  bool
	filling bool

	pos uint32 // round-robin cursor for Pick, advanced atomically
}
|
||||||
|
|
||||||
|
// String returns a debug summary of the pool's current state.
func (h *hostConnPool) String() string {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return fmt.Sprintf("[filling=%v closed=%v conns=%v size=%v host=%v]",
		h.filling, h.closed, len(h.conns), h.size, h.host)
}
|
||||||
|
|
||||||
|
func newHostConnPool(session *Session, host *HostInfo, port, size int,
|
||||||
|
keyspace string) *hostConnPool {
|
||||||
|
|
||||||
|
pool := &hostConnPool{
|
||||||
|
session: session,
|
||||||
|
host: host,
|
||||||
|
port: port,
|
||||||
|
addr: (&net.TCPAddr{IP: host.ConnectAddress(), Port: host.Port()}).String(),
|
||||||
|
size: size,
|
||||||
|
keyspace: keyspace,
|
||||||
|
conns: make([]*Conn, 0, size),
|
||||||
|
filling: false,
|
||||||
|
closed: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
// the pool is not filled or connected
|
||||||
|
return pool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pick a connection from this connection pool for the given query.
// Returns nil if the pool is closed or currently empty; a background fill is
// kicked off whenever the pool is below its target size.
func (pool *hostConnPool) Pick() *Conn {
	pool.mu.RLock()
	defer pool.mu.RUnlock()

	if pool.closed {
		return nil
	}

	size := len(pool.conns)
	if size < pool.size {
		// try to fill the pool
		go pool.fill()

		if size == 0 {
			return nil
		}
	}

	// rotate the starting index so load spreads across connections
	pos := int(atomic.AddUint32(&pool.pos, 1) - 1)

	var (
		leastBusyConn    *Conn
		streamsAvailable int
	)

	// find the conn which has the most available streams, this is racy
	for i := 0; i < size; i++ {
		conn := pool.conns[(pos+i)%size]
		if streams := conn.AvailableStreams(); streams > streamsAvailable {
			leastBusyConn = conn
			streamsAvailable = streams
		}
	}

	// nil when no connection reported any available streams
	return leastBusyConn
}
|
||||||
|
|
||||||
|
//Size returns the number of connections currently active in the pool
func (pool *hostConnPool) Size() int {
	pool.mu.RLock()
	defer pool.mu.RUnlock()

	return len(pool.conns)
}
|
||||||
|
|
||||||
|
//Close the connection pool
// Marks the pool closed, detaches the connection slice under the lock, then
// closes each connection with the lock released so conn.Close callbacks
// cannot deadlock against this mutex (see TODO below).
func (pool *hostConnPool) Close() {
	pool.mu.Lock()

	if pool.closed {
		pool.mu.Unlock()
		return
	}
	pool.closed = true

	// ensure we dont try to reacquire the lock in handleError
	// TODO: improve this as the following can happen
	// 1) we have locked pool.mu write lock
	// 2) conn.Close calls conn.closeWithError(nil)
	// 3) conn.closeWithError calls conn.Close() which returns an error
	// 4) conn.closeWithError calls pool.HandleError with the error from conn.Close
	// 5) pool.HandleError tries to lock pool.mu
	// deadlock

	// empty the pool
	conns := pool.conns
	pool.conns = nil

	pool.mu.Unlock()

	// close the connections
	for _, conn := range conns {
		conn.Close()
	}
}
|
||||||
|
|
||||||
|
// Fill the connection pool
// Tops the pool up to its target size. Uses a read-lock check, then upgrades
// to a write lock (re-validating everything, since the lock was released in
// between) and sets the `filling` flag so only one filler runs at a time.
// The first connection is dialed synchronously to fail fast on unreachable
// hosts; the rest are dialed in a background goroutine.
func (pool *hostConnPool) fill() {
	pool.mu.RLock()
	// avoid filling a closed pool, or concurrent filling
	if pool.closed || pool.filling {
		pool.mu.RUnlock()
		return
	}

	// determine the filling work to be done
	startCount := len(pool.conns)
	fillCount := pool.size - startCount

	// avoid filling a full (or overfull) pool
	if fillCount <= 0 {
		pool.mu.RUnlock()
		return
	}

	// switch from read to write lock
	pool.mu.RUnlock()
	pool.mu.Lock()

	// double check everything since the lock was released
	startCount = len(pool.conns)
	fillCount = pool.size - startCount
	if pool.closed || pool.filling || fillCount <= 0 {
		// looks like another goroutine already beat this
		// goroutine to the filling
		pool.mu.Unlock()
		return
	}

	// ok fill the pool
	pool.filling = true

	// allow others to access the pool while filling
	pool.mu.Unlock()
	// only this goroutine should make calls to fill/empty the pool at this
	// point until after this routine or its subordinates calls
	// fillingStopped

	// fill only the first connection synchronously
	if startCount == 0 {
		err := pool.connect()
		pool.logConnectErr(err)

		if err != nil {
			// probably unreachable host
			pool.fillingStopped(true)

			// this is call with the connection pool mutex held, this call will
			// then recursively try to lock it again. FIXME
			if pool.session.cfg.ConvictionPolicy.AddFailure(err, pool.host) {
				go pool.session.handleNodeDown(pool.host.ConnectAddress(), pool.port)
			}
			return
		}

		// filled one
		fillCount--
	}

	// fill the rest of the pool asynchronously
	go func() {
		err := pool.connectMany(fillCount)

		// mark the end of filling
		pool.fillingStopped(err != nil)
	}()
}
|
||||||
|
|
||||||
|
func (pool *hostConnPool) logConnectErr(err error) {
|
||||||
|
if opErr, ok := err.(*net.OpError); ok && (opErr.Op == "dial" || opErr.Op == "read") {
|
||||||
|
// connection refused
|
||||||
|
// these are typical during a node outage so avoid log spam.
|
||||||
|
if gocqlDebug {
|
||||||
|
Logger.Printf("unable to dial %q: %v\n", pool.host.ConnectAddress(), err)
|
||||||
|
}
|
||||||
|
} else if err != nil {
|
||||||
|
// unexpected error
|
||||||
|
Logger.Printf("error: failed to connect to %s due to error: %v", pool.addr, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// transition back to a not-filling state.
// On error, sleeps a small randomized backoff first so failed fill attempts
// are not retried back-to-back while the host recovers.
func (pool *hostConnPool) fillingStopped(hadError bool) {
	if hadError {
		// wait for some time to avoid back-to-back filling
		// this provides some time between failed attempts
		// to fill the pool for the host to recover
		time.Sleep(time.Duration(rand.Int31n(100)+31) * time.Millisecond)
	}

	pool.mu.Lock()
	pool.filling = false
	pool.mu.Unlock()
}
|
||||||
|
|
||||||
|
// connectMany creates new connections concurrent.
|
||||||
|
func (pool *hostConnPool) connectMany(count int) error {
|
||||||
|
if count == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
wg sync.WaitGroup
|
||||||
|
mu sync.Mutex
|
||||||
|
connectErr error
|
||||||
|
)
|
||||||
|
wg.Add(count)
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
err := pool.connect()
|
||||||
|
pool.logConnectErr(err)
|
||||||
|
if err != nil {
|
||||||
|
mu.Lock()
|
||||||
|
connectErr = err
|
||||||
|
mu.Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
// wait for all connections are done
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
return connectErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// create a new connection to the host and add it to the pool
// Retries dialing per the session's ReconnectionPolicy, giving up early on
// non-temporary network errors. On success the connection is switched to the
// pool's keyspace (if any) and appended to the pool, unless the pool was
// closed in the meantime, in which case the new connection is discarded.
func (pool *hostConnPool) connect() (err error) {
	// TODO: provide a more robust connection retry mechanism, we should also
	// be able to detect hosts that come up by trying to connect to downed ones.
	// try to connect
	var conn *Conn
	reconnectionPolicy := pool.session.cfg.ReconnectionPolicy
	for i := 0; i < reconnectionPolicy.GetMaxRetries(); i++ {
		conn, err = pool.session.connect(pool.session.ctx, pool.host, pool)
		if err == nil {
			break
		}
		if opErr, isOpErr := err.(*net.OpError); isOpErr {
			// if the error is not a temporary error (ex: network unreachable) don't
			// retry
			if !opErr.Temporary() {
				break
			}
		}
		if gocqlDebug {
			Logger.Printf("connection failed %q: %v, reconnecting with %T\n",
				pool.host.ConnectAddress(), err, reconnectionPolicy)
		}
		time.Sleep(reconnectionPolicy.GetInterval(i))
	}

	if err != nil {
		return err
	}

	if pool.keyspace != "" {
		// set the keyspace
		if err = conn.UseKeyspace(pool.keyspace); err != nil {
			conn.Close()
			return err
		}
	}

	// add the Conn to the pool
	pool.mu.Lock()
	defer pool.mu.Unlock()

	if pool.closed {
		// pool was closed while we were dialing; drop the connection
		conn.Close()
		return nil
	}

	pool.conns = append(pool.conns, conn)

	return nil
}
|
||||||
|
|
||||||
|
// handle any error from a Conn
// Invoked by connections when they fail. If the connection is closed it is
// removed from the pool (swap-with-last, order not preserved) and a refill
// is started in the background.
func (pool *hostConnPool) HandleError(conn *Conn, err error, closed bool) {
	if !closed {
		// still an open connection, so continue using it
		return
	}

	// TODO: track the number of errors per host and detect when a host is dead,
	// then also have something which can detect when a host comes back.
	pool.mu.Lock()
	defer pool.mu.Unlock()

	if pool.closed {
		// pool closed
		return
	}

	// find the connection index
	for i, candidate := range pool.conns {
		if candidate == conn {
			// remove the connection, not preserving order
			pool.conns[i], pool.conns = pool.conns[len(pool.conns)-1], pool.conns[:len(pool.conns)-1]

			// lost a connection, so fill the pool
			go pool.fill()
			break
		}
	}
}
|
@ -0,0 +1,488 @@
|
|||||||
|
package gocql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
crand "crypto/rand"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
	// randr is a package-level PRNG used for host shuffling; it is guarded
	// by mutRandr because rand.Rand is not safe for concurrent use.
	randr    *rand.Rand
	mutRandr sync.Mutex
)
|
||||||
|
|
||||||
|
// init seeds the package-level PRNG from crypto/rand so host shuffling
// differs between processes; failure to read entropy is fatal.
func init() {
	b := make([]byte, 4)
	if _, err := crand.Read(b); err != nil {
		panic(fmt.Sprintf("unable to seed random number generator: %v", err))
	}

	// readInt (defined elsewhere in this package) decodes b into an integer
	randr = rand.New(rand.NewSource(int64(readInt(b))))
}
|
||||||
|
|
||||||
|
// Ensure that the atomic variable is aligned to a 64bit boundary
// so that atomic operations can be applied on 32bit architectures.
type controlConn struct {
	started      int32 // 0 idle, 1 heartBeat running, -1 after close
	reconnecting int32 // CAS guard so only one reconnect runs at a time

	session *Session
	conn    atomic.Value // holds a *connHost (may be a typed nil)

	retry RetryPolicy

	quit chan struct{} // signals the heartBeat goroutine to stop
}
|
||||||
|
|
||||||
|
func createControlConn(session *Session) *controlConn {
|
||||||
|
control := &controlConn{
|
||||||
|
session: session,
|
||||||
|
quit: make(chan struct{}),
|
||||||
|
retry: &SimpleRetryPolicy{NumRetries: 3},
|
||||||
|
}
|
||||||
|
|
||||||
|
control.conn.Store((*connHost)(nil))
|
||||||
|
|
||||||
|
return control
|
||||||
|
}
|
||||||
|
|
||||||
|
// heartBeat periodically probes the control connection with an OPTIONS frame
// and reconnects when the probe fails. The CAS on started ensures only one
// heartbeat loop ever runs; the loop exits when quit is signaled.
func (c *controlConn) heartBeat() {
	if !atomic.CompareAndSwapInt32(&c.started, 0, 1) {
		return
	}

	sleepTime := 1 * time.Second
	timer := time.NewTimer(sleepTime)
	defer timer.Stop()

	for {
		timer.Reset(sleepTime)

		select {
		case <-c.quit:
			return
		case <-timer.C:
		}

		resp, err := c.writeFrame(&writeOptionsFrame{})
		if err != nil {
			goto reconn
		}

		switch resp.(type) {
		case *supportedFrame:
			// Everything ok
			sleepTime = 5 * time.Second
			continue
		case error:
			goto reconn
		default:
			panic(fmt.Sprintf("gocql: unknown frame in response to options: %T", resp))
		}

	reconn:
		// try to connect a bit faster
		sleepTime = 1 * time.Second
		c.reconnect(true)
		continue
	}
}
|
||||||
|
|
||||||
|
// hostLookupPreferV4 makes hostInfo prefer IPv4 results from DNS when the
// GOCQL_HOST_LOOKUP_PREFER_V4 environment variable is set to "true".
var hostLookupPreferV4 = os.Getenv("GOCQL_HOST_LOOKUP_PREFER_V4") == "true"
|
||||||
|
|
||||||
|
func hostInfo(addr string, defaultPort int) ([]*HostInfo, error) {
|
||||||
|
var port int
|
||||||
|
host, portStr, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
host = addr
|
||||||
|
port = defaultPort
|
||||||
|
} else {
|
||||||
|
port, err = strconv.Atoi(portStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var hosts []*HostInfo
|
||||||
|
|
||||||
|
// Check if host is a literal IP address
|
||||||
|
if ip := net.ParseIP(host); ip != nil {
|
||||||
|
hosts = append(hosts, &HostInfo{hostname: host, connectAddress: ip, port: port})
|
||||||
|
return hosts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look up host in DNS
|
||||||
|
ips, err := LookupIP(host)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else if len(ips) == 0 {
|
||||||
|
return nil, fmt.Errorf("No IP's returned from DNS lookup for %q", addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter to v4 addresses if any present
|
||||||
|
if hostLookupPreferV4 {
|
||||||
|
var preferredIPs []net.IP
|
||||||
|
for _, v := range ips {
|
||||||
|
if v4 := v.To4(); v4 != nil {
|
||||||
|
preferredIPs = append(preferredIPs, v4)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(preferredIPs) != 0 {
|
||||||
|
ips = preferredIPs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ip := range ips {
|
||||||
|
hosts = append(hosts, &HostInfo{hostname: host, connectAddress: ip, port: port})
|
||||||
|
}
|
||||||
|
|
||||||
|
return hosts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func shuffleHosts(hosts []*HostInfo) []*HostInfo {
|
||||||
|
shuffled := make([]*HostInfo, len(hosts))
|
||||||
|
copy(shuffled, hosts)
|
||||||
|
|
||||||
|
mutRandr.Lock()
|
||||||
|
randr.Shuffle(len(hosts), func(i, j int) {
|
||||||
|
shuffled[i], shuffled[j] = shuffled[j], shuffled[i]
|
||||||
|
})
|
||||||
|
mutRandr.Unlock()
|
||||||
|
|
||||||
|
return shuffled
|
||||||
|
}
|
||||||
|
|
||||||
|
// shuffleDial tries each endpoint in random order and returns the first
// successful connection. It dials with a copy of the session conn config
// that has write coalescing disabled. On total failure the last dial error
// is returned.
func (c *controlConn) shuffleDial(endpoints []*HostInfo) (*Conn, error) {
	// shuffle endpoints so not all drivers will connect to the same initial
	// node.
	shuffled := shuffleHosts(endpoints)

	cfg := *c.session.connCfg
	cfg.disableCoalesce = true

	var err error
	for _, host := range shuffled {
		var conn *Conn
		conn, err = c.session.dial(c.session.ctx, host, &cfg, c)
		if err == nil {
			return conn, nil
		}

		Logger.Printf("gocql: unable to dial control conn %v: %v\n", host.ConnectAddress(), err)
	}

	return nil, err
}
|
||||||
|
|
||||||
|
// this is going to be version dependant and a nightmare to maintain :(
// protocolSupportRe extracts the server's greatest supported protocol
// version from the text of a protocol-negotiation error message.
var protocolSupportRe = regexp.MustCompile(`the lowest supported version is \d+ and the greatest is (\d+)$`)
|
||||||
|
|
||||||
|
// parseProtocolFromError extracts a usable protocol version from a
// negotiation error: first by matching the server's error text, then by
// falling back to the version byte in the error frame header. Returns 0 when
// no version can be determined.
func parseProtocolFromError(err error) int {
	// I really wish this had the actual info in the error frame...
	matches := protocolSupportRe.FindAllStringSubmatch(err.Error(), -1)
	if len(matches) != 1 || len(matches[0]) != 2 {
		if verr, ok := err.(*protocolError); ok {
			return int(verr.frame.Header().version.version())
		}
		return 0
	}

	max, err := strconv.Atoi(matches[0][1])
	if err != nil {
		return 0
	}

	return max
}
|
||||||
|
|
||||||
|
// discoverProtocol determines the native protocol version to use by dialing
// hosts (in random order) with the maximum version we support: a successful
// dial means that version works; a version-negotiation error tells us the
// server's maximum. Returns 0 with the last error if no host answered.
func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) {
	hosts = shuffleHosts(hosts)

	connCfg := *c.session.connCfg
	connCfg.ProtoVersion = 4 // TODO: define maxProtocol

	handler := connErrorHandlerFn(func(c *Conn, err error, closed bool) {
		// we should never get here, but if we do it means we connected to a
		// host successfully which means our attempted protocol version worked
		if !closed {
			c.Close()
		}
	})

	var err error
	for _, host := range hosts {
		var conn *Conn
		conn, err = c.session.dial(c.session.ctx, host, &connCfg, handler)
		if conn != nil {
			// the probe connection is only needed for the handshake
			conn.Close()
		}

		if err == nil {
			return connCfg.ProtoVersion, nil
		}

		if proto := parseProtocolFromError(err); proto > 0 {
			return proto, nil
		}
	}

	return 0, err
}
|
||||||
|
|
||||||
|
func (c *controlConn) connect(hosts []*HostInfo) error {
|
||||||
|
if len(hosts) == 0 {
|
||||||
|
return errors.New("control: no endpoints specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := c.shuffleDial(hosts)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("control: unable to connect to initial hosts: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.setupConn(conn); err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return fmt.Errorf("control: unable to setup connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// we could fetch the initial ring here and update initial host data. So that
|
||||||
|
// when we return from here we have a ring topology ready to go.
|
||||||
|
|
||||||
|
go c.heartBeat()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// connHost pairs the active control connection with the host it is
// connected to; it is stored atomically in controlConn.conn.
type connHost struct {
	conn *Conn
	host *HostInfo
}
|
||||||
|
|
||||||
|
// setupConn registers for server events on conn, fetches the local host
// info, publishes the new connection atomically, and marks the host up.
func (c *controlConn) setupConn(conn *Conn) error {
	if err := c.registerEvents(conn); err != nil {
		conn.Close()
		return err
	}

	// TODO(zariel): do we need to fetch host info everytime
	// the control conn connects? Surely we have it cached?
	host, err := conn.localHostInfo(context.TODO())
	if err != nil {
		return err
	}

	ch := &connHost{
		conn: conn,
		host: host,
	}

	c.conn.Store(ch)
	c.session.handleNodeUp(host.ConnectAddress(), host.Port(), false)

	return nil
}
|
||||||
|
|
||||||
|
// registerEvents subscribes conn to the cluster event categories that are
// not disabled in the session config; it is a no-op if all are disabled.
// A non-READY response to the REGISTER request is treated as an error.
func (c *controlConn) registerEvents(conn *Conn) error {
	var events []string

	if !c.session.cfg.Events.DisableTopologyEvents {
		events = append(events, "TOPOLOGY_CHANGE")
	}
	if !c.session.cfg.Events.DisableNodeStatusEvents {
		events = append(events, "STATUS_CHANGE")
	}
	if !c.session.cfg.Events.DisableSchemaEvents {
		events = append(events, "SCHEMA_CHANGE")
	}

	if len(events) == 0 {
		return nil
	}

	framer, err := conn.exec(context.Background(),
		&writeRegisterFrame{
			events: events,
		}, nil)
	if err != nil {
		return err
	}

	frame, err := framer.parseFrame()
	if err != nil {
		return err
	} else if _, ok := frame.(*readyFrame); !ok {
		return fmt.Errorf("unexpected frame in response to register: got %T: %v\n", frame, frame)
	}

	return nil
}
|
||||||
|
|
||||||
|
// reconnect re-establishes the control connection: it first retries the
// previous host, then falls back to a round-robin pick from the ring, and
// finally to a fresh connect over the configured endpoints. The CAS on
// reconnecting ensures only one reconnect attempt runs at a time. When
// refreshring is set and reconnection succeeds, the ring topology is
// refreshed.
func (c *controlConn) reconnect(refreshring bool) {
	if !atomic.CompareAndSwapInt32(&c.reconnecting, 0, 1) {
		return
	}
	defer atomic.StoreInt32(&c.reconnecting, 0)
	// TODO: simplify this function, use session.ring to get hosts instead of the
	// connection pool

	var host *HostInfo
	ch := c.getConn()
	if ch != nil {
		host = ch.host
		ch.conn.Close()
	}

	var newConn *Conn
	if host != nil {
		// try to connect to the old host
		conn, err := c.session.connect(c.session.ctx, host, c)
		if err != nil {
			// host is dead
			// TODO: this is replicated in a few places
			if c.session.cfg.ConvictionPolicy.AddFailure(err, host) {
				c.session.handleNodeDown(host.ConnectAddress(), host.Port())
			}
		} else {
			newConn = conn
		}
	}

	// TODO: should have our own round-robin for hosts so that we can try each
	// in succession and guarantee that we get a different host each time.
	if newConn == nil {
		host := c.session.ring.rrHost()
		if host == nil {
			c.connect(c.session.ring.endpoints)
			return
		}

		var err error
		newConn, err = c.session.connect(c.session.ctx, host, c)
		if err != nil {
			// TODO: add log handler for things like this
			return
		}
	}

	if err := c.setupConn(newConn); err != nil {
		newConn.Close()
		Logger.Printf("gocql: control unable to register events: %v\n", err)
		return
	}

	if refreshring {
		c.session.hostSource.refreshRing()
	}
}
|
||||||
|
|
||||||
|
// HandleError reacts to a failure on the control connection: if the failed
// conn is (still) the current one, a reconnect is started. Errors on a conn
// that was already replaced are ignored.
func (c *controlConn) HandleError(conn *Conn, err error, closed bool) {
	if !closed {
		return
	}

	oldConn := c.getConn()

	// If connection has long gone, and not been attempted for awhile,
	// it's possible to have oldConn as nil here (#1297).
	if oldConn != nil && oldConn.conn != conn {
		return
	}

	c.reconnect(false)
}
|
||||||
|
|
||||||
|
// getConn returns the current control connection, or nil when none is
// established (the atomic.Value always holds a *connHost, possibly typed
// nil, so the assertion cannot panic).
func (c *controlConn) getConn() *connHost {
	return c.conn.Load().(*connHost)
}
|
||||||
|
|
||||||
|
func (c *controlConn) writeFrame(w frameWriter) (frame, error) {
|
||||||
|
ch := c.getConn()
|
||||||
|
if ch == nil {
|
||||||
|
return nil, errNoControl
|
||||||
|
}
|
||||||
|
|
||||||
|
framer, err := ch.conn.exec(context.Background(), w, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return framer.parseFrame()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *controlConn) withConnHost(fn func(*connHost) *Iter) *Iter {
|
||||||
|
const maxConnectAttempts = 5
|
||||||
|
connectAttempts := 0
|
||||||
|
|
||||||
|
for i := 0; i < maxConnectAttempts; i++ {
|
||||||
|
ch := c.getConn()
|
||||||
|
if ch == nil {
|
||||||
|
if connectAttempts > maxConnectAttempts {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
connectAttempts++
|
||||||
|
|
||||||
|
c.reconnect(false)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return fn(ch)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Iter{err: errNoControl}
|
||||||
|
}
|
||||||
|
|
||||||
|
// withConn adapts withConnHost for callbacks that only need the *Conn.
func (c *controlConn) withConn(fn func(*Conn) *Iter) *Iter {
	return c.withConnHost(func(ch *connHost) *Iter {
		return fn(ch.conn)
	})
}
|
||||||
|
|
||||||
|
// query will return nil if the connection is closed or nil
|
||||||
|
func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter) {
|
||||||
|
q := c.session.Query(statement, values...).Consistency(One).RoutingKey([]byte{}).Trace(nil)
|
||||||
|
|
||||||
|
for {
|
||||||
|
iter = c.withConn(func(conn *Conn) *Iter {
|
||||||
|
return conn.executeQuery(context.TODO(), q)
|
||||||
|
})
|
||||||
|
|
||||||
|
if gocqlDebug && iter.err != nil {
|
||||||
|
Logger.Printf("control: error executing %q: %v\n", statement, iter.err)
|
||||||
|
}
|
||||||
|
|
||||||
|
q.AddAttempts(1, c.getConn().host)
|
||||||
|
if iter.err == nil || !c.retry.Attempt(q) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// awaitSchemaAgreement blocks until the cluster reports schema agreement on
// the control connection, returning any error encountered.
func (c *controlConn) awaitSchemaAgreement() error {
	return c.withConn(func(conn *Conn) *Iter {
		return &Iter{err: conn.awaitSchemaAgreement(context.TODO())}
	}).err
}
|
||||||
|
|
||||||
|
// close stops the heartbeat goroutine (the CAS flips started from 1 to the
// -1 "closed" sentinel, and the unbuffered send blocks until the heartbeat
// receives it) and closes the current connection, if any.
func (c *controlConn) close() {
	if atomic.CompareAndSwapInt32(&c.started, 1, -1) {
		c.quit <- struct{}{}
	}

	ch := c.getConn()
	if ch != nil {
		ch.conn.Close()
	}
}
|
||||||
|
|
||||||
|
// errNoControl is returned when no control connection is available to serve
// a request.
var errNoControl = errors.New("gocql: no control connection available")
|
@ -0,0 +1,11 @@
|
|||||||
|
// Copyright (c) 2012 The gocql Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package gocql
|
||||||
|
|
||||||
|
// Duration represents a span of time decomposed into months, days and
// nanoseconds (the same decomposition the CQL duration type uses, where the
// three components do not reduce to a single fixed length).
type Duration struct {
	Months      int32
	Days        int32
	Nanoseconds int64
}
|
@ -0,0 +1,5 @@
|
|||||||
|
// +build !gocql_debug

package gocql

// gocqlDebug gates verbose driver logging; this default build disables it.
// Build with -tags gocql_debug to enable.
const gocqlDebug = false
|
@ -0,0 +1,5 @@
|
|||||||
|
// +build gocql_debug

package gocql

// gocqlDebug gates verbose driver logging; this tagged build enables it.
const gocqlDebug = true
|
@ -0,0 +1,9 @@
|
|||||||
|
// Copyright (c) 2012-2015 The gocql Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// Package gocql implements a fast and robust Cassandra driver for the
|
||||||
|
// Go programming language.
|
||||||
|
package gocql // import "github.com/gocql/gocql"
|
||||||
|
|
||||||
|
// TODO(tux21b): write more docs.
|
@ -0,0 +1,125 @@
|
|||||||
|
package gocql
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// Error codes carried in server ERROR frames, as defined by the Cassandra
// native protocol.
const (
	errServer          = 0x0000
	errProtocol        = 0x000A
	errCredentials     = 0x0100
	errUnavailable     = 0x1000
	errOverloaded      = 0x1001
	errBootstrapping   = 0x1002
	errTruncate        = 0x1003
	errWriteTimeout    = 0x1100
	errReadTimeout     = 0x1200
	errReadFailure     = 0x1300
	errFunctionFailure = 0x1400
	errWriteFailure    = 0x1500
	errCDCWriteFailure = 0x1600
	errSyntax          = 0x2000
	errUnauthorized    = 0x2100
	errInvalid         = 0x2200
	errConfig          = 0x2300
	errAlreadyExists   = 0x2400
	errUnprepared      = 0x2500
)
|
||||||
|
|
||||||
|
// RequestError is implemented by errors derived from a server ERROR frame,
// exposing the protocol error code and message alongside the error text.
type RequestError interface {
	Code() int
	Message() string
	Error() string
}
|
||||||
|
|
||||||
|
// errorFrame is the decoded body of a server ERROR response, carrying the
// protocol error code and message; the exported Request* error types embed
// it.
type errorFrame struct {
	frameHeader

	code    int
	message string
}

// Code returns the protocol error code.
func (e errorFrame) Code() int {
	return e.code
}

// Message returns the server-supplied error message.
func (e errorFrame) Message() string {
	return e.message
}

// Error implements the error interface using the server message.
func (e errorFrame) Error() string {
	return e.Message()
}

// String renders the code (hex) and message for debugging.
func (e errorFrame) String() string {
	return fmt.Sprintf("[error code=%x message=%q]", e.code, e.message)
}
|
||||||
|
|
||||||
|
type RequestErrUnavailable struct {
|
||||||
|
errorFrame
|
||||||
|
Consistency Consistency
|
||||||
|
Required int
|
||||||
|
Alive int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *RequestErrUnavailable) String() string {
|
||||||
|
return fmt.Sprintf("[request_error_unavailable consistency=%s required=%d alive=%d]", e.Consistency, e.Required, e.Alive)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ErrorMap map[string]uint16
|
||||||
|
|
||||||
|
type RequestErrWriteTimeout struct {
|
||||||
|
errorFrame
|
||||||
|
Consistency Consistency
|
||||||
|
Received int
|
||||||
|
BlockFor int
|
||||||
|
WriteType string
|
||||||
|
}
|
||||||
|
|
||||||
|
type RequestErrWriteFailure struct {
|
||||||
|
errorFrame
|
||||||
|
Consistency Consistency
|
||||||
|
Received int
|
||||||
|
BlockFor int
|
||||||
|
NumFailures int
|
||||||
|
WriteType string
|
||||||
|
ErrorMap ErrorMap
|
||||||
|
}
|
||||||
|
|
||||||
|
type RequestErrCDCWriteFailure struct {
|
||||||
|
errorFrame
|
||||||
|
}
|
||||||
|
|
||||||
|
type RequestErrReadTimeout struct {
|
||||||
|
errorFrame
|
||||||
|
Consistency Consistency
|
||||||
|
Received int
|
||||||
|
BlockFor int
|
||||||
|
DataPresent byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type RequestErrAlreadyExists struct {
|
||||||
|
errorFrame
|
||||||
|
Keyspace string
|
||||||
|
Table string
|
||||||
|
}
|
||||||
|
|
||||||
|
type RequestErrUnprepared struct {
|
||||||
|
errorFrame
|
||||||
|
StatementId []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type RequestErrReadFailure struct {
|
||||||
|
errorFrame
|
||||||
|
Consistency Consistency
|
||||||
|
Received int
|
||||||
|
BlockFor int
|
||||||
|
NumFailures int
|
||||||
|
DataPresent bool
|
||||||
|
ErrorMap ErrorMap
|
||||||
|
}
|
||||||
|
|
||||||
|
type RequestErrFunctionFailure struct {
|
||||||
|
errorFrame
|
||||||
|
Keyspace string
|
||||||
|
Function string
|
||||||
|
ArgTypes []string
|
||||||
|
}
|
@ -0,0 +1,293 @@
|
|||||||
|
package gocql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type eventDebouncer struct {
|
||||||
|
name string
|
||||||
|
timer *time.Timer
|
||||||
|
mu sync.Mutex
|
||||||
|
events []frame
|
||||||
|
|
||||||
|
callback func([]frame)
|
||||||
|
quit chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newEventDebouncer(name string, eventHandler func([]frame)) *eventDebouncer {
|
||||||
|
e := &eventDebouncer{
|
||||||
|
name: name,
|
||||||
|
quit: make(chan struct{}),
|
||||||
|
timer: time.NewTimer(eventDebounceTime),
|
||||||
|
callback: eventHandler,
|
||||||
|
}
|
||||||
|
e.timer.Stop()
|
||||||
|
go e.flusher()
|
||||||
|
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *eventDebouncer) stop() {
|
||||||
|
e.quit <- struct{}{} // sync with flusher
|
||||||
|
close(e.quit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *eventDebouncer) flusher() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-e.timer.C:
|
||||||
|
e.mu.Lock()
|
||||||
|
e.flush()
|
||||||
|
e.mu.Unlock()
|
||||||
|
case <-e.quit:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
eventBufferSize = 1000
|
||||||
|
eventDebounceTime = 1 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// flush must be called with mu locked
|
||||||
|
func (e *eventDebouncer) flush() {
|
||||||
|
if len(e.events) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// if the flush interval is faster than the callback then we will end up calling
|
||||||
|
// the callback multiple times, probably a bad idea. In this case we could drop
|
||||||
|
// frames?
|
||||||
|
go e.callback(e.events)
|
||||||
|
e.events = make([]frame, 0, eventBufferSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *eventDebouncer) debounce(frame frame) {
|
||||||
|
e.mu.Lock()
|
||||||
|
e.timer.Reset(eventDebounceTime)
|
||||||
|
|
||||||
|
// TODO: probably need a warning to track if this threshold is too low
|
||||||
|
if len(e.events) < eventBufferSize {
|
||||||
|
e.events = append(e.events, frame)
|
||||||
|
} else {
|
||||||
|
Logger.Printf("%s: buffer full, dropping event frame: %s", e.name, frame)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) handleEvent(framer *framer) {
|
||||||
|
frame, err := framer.parseFrame()
|
||||||
|
if err != nil {
|
||||||
|
// TODO: logger
|
||||||
|
Logger.Printf("gocql: unable to parse event frame: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if gocqlDebug {
|
||||||
|
Logger.Printf("gocql: handling frame: %v\n", frame)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch f := frame.(type) {
|
||||||
|
case *schemaChangeKeyspace, *schemaChangeFunction,
|
||||||
|
*schemaChangeTable, *schemaChangeAggregate, *schemaChangeType:
|
||||||
|
|
||||||
|
s.schemaEvents.debounce(frame)
|
||||||
|
case *topologyChangeEventFrame, *statusChangeEventFrame:
|
||||||
|
s.nodeEvents.debounce(frame)
|
||||||
|
default:
|
||||||
|
Logger.Printf("gocql: invalid event frame (%T): %v\n", f, f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) handleSchemaEvent(frames []frame) {
|
||||||
|
// TODO: debounce events
|
||||||
|
for _, frame := range frames {
|
||||||
|
switch f := frame.(type) {
|
||||||
|
case *schemaChangeKeyspace:
|
||||||
|
s.schemaDescriber.clearSchema(f.keyspace)
|
||||||
|
s.handleKeyspaceChange(f.keyspace, f.change)
|
||||||
|
case *schemaChangeTable:
|
||||||
|
s.schemaDescriber.clearSchema(f.keyspace)
|
||||||
|
case *schemaChangeAggregate:
|
||||||
|
s.schemaDescriber.clearSchema(f.keyspace)
|
||||||
|
case *schemaChangeFunction:
|
||||||
|
s.schemaDescriber.clearSchema(f.keyspace)
|
||||||
|
case *schemaChangeType:
|
||||||
|
s.schemaDescriber.clearSchema(f.keyspace)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) handleKeyspaceChange(keyspace, change string) {
|
||||||
|
s.control.awaitSchemaAgreement()
|
||||||
|
s.policy.KeyspaceChanged(KeyspaceUpdateEvent{Keyspace: keyspace, Change: change})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) handleNodeEvent(frames []frame) {
|
||||||
|
type nodeEvent struct {
|
||||||
|
change string
|
||||||
|
host net.IP
|
||||||
|
port int
|
||||||
|
}
|
||||||
|
|
||||||
|
events := make(map[string]*nodeEvent)
|
||||||
|
|
||||||
|
for _, frame := range frames {
|
||||||
|
// TODO: can we be sure the order of events in the buffer is correct?
|
||||||
|
switch f := frame.(type) {
|
||||||
|
case *topologyChangeEventFrame:
|
||||||
|
event, ok := events[f.host.String()]
|
||||||
|
if !ok {
|
||||||
|
event = &nodeEvent{change: f.change, host: f.host, port: f.port}
|
||||||
|
events[f.host.String()] = event
|
||||||
|
}
|
||||||
|
event.change = f.change
|
||||||
|
|
||||||
|
case *statusChangeEventFrame:
|
||||||
|
event, ok := events[f.host.String()]
|
||||||
|
if !ok {
|
||||||
|
event = &nodeEvent{change: f.change, host: f.host, port: f.port}
|
||||||
|
events[f.host.String()] = event
|
||||||
|
}
|
||||||
|
event.change = f.change
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, f := range events {
|
||||||
|
if gocqlDebug {
|
||||||
|
Logger.Printf("gocql: dispatching event: %+v\n", f)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch f.change {
|
||||||
|
case "NEW_NODE":
|
||||||
|
s.handleNewNode(f.host, f.port, true)
|
||||||
|
case "REMOVED_NODE":
|
||||||
|
s.handleRemovedNode(f.host, f.port)
|
||||||
|
case "MOVED_NODE":
|
||||||
|
// java-driver handles this, not mentioned in the spec
|
||||||
|
// TODO(zariel): refresh token map
|
||||||
|
case "UP":
|
||||||
|
s.handleNodeUp(f.host, f.port, true)
|
||||||
|
case "DOWN":
|
||||||
|
s.handleNodeDown(f.host, f.port)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) addNewNode(host *HostInfo) {
|
||||||
|
if s.cfg.filterHost(host) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
host.setState(NodeUp)
|
||||||
|
s.pool.addHost(host)
|
||||||
|
s.policy.AddHost(host)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) handleNewNode(ip net.IP, port int, waitForBinary bool) {
|
||||||
|
if gocqlDebug {
|
||||||
|
Logger.Printf("gocql: Session.handleNewNode: %s:%d\n", ip.String(), port)
|
||||||
|
}
|
||||||
|
|
||||||
|
ip, port = s.cfg.translateAddressPort(ip, port)
|
||||||
|
|
||||||
|
// Get host info and apply any filters to the host
|
||||||
|
hostInfo, err := s.hostSource.getHostInfo(ip, port)
|
||||||
|
if err != nil {
|
||||||
|
Logger.Printf("gocql: events: unable to fetch host info for (%s:%d): %v\n", ip, port, err)
|
||||||
|
return
|
||||||
|
} else if hostInfo == nil {
|
||||||
|
// If hostInfo is nil, this host was filtered out by cfg.HostFilter
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if t := hostInfo.Version().nodeUpDelay(); t > 0 && waitForBinary {
|
||||||
|
time.Sleep(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// should this handle token moving?
|
||||||
|
hostInfo = s.ring.addOrUpdate(hostInfo)
|
||||||
|
|
||||||
|
s.addNewNode(hostInfo)
|
||||||
|
|
||||||
|
if s.control != nil && !s.cfg.IgnorePeerAddr {
|
||||||
|
// TODO(zariel): debounce ring refresh
|
||||||
|
s.hostSource.refreshRing()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) handleRemovedNode(ip net.IP, port int) {
|
||||||
|
if gocqlDebug {
|
||||||
|
Logger.Printf("gocql: Session.handleRemovedNode: %s:%d\n", ip.String(), port)
|
||||||
|
}
|
||||||
|
|
||||||
|
ip, port = s.cfg.translateAddressPort(ip, port)
|
||||||
|
|
||||||
|
// we remove all nodes but only add ones which pass the filter
|
||||||
|
host := s.ring.getHost(ip)
|
||||||
|
if host == nil {
|
||||||
|
host = &HostInfo{connectAddress: ip, port: port}
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
host.setState(NodeDown)
|
||||||
|
s.policy.RemoveHost(host)
|
||||||
|
s.pool.removeHost(ip)
|
||||||
|
s.ring.removeHost(ip)
|
||||||
|
|
||||||
|
if !s.cfg.IgnorePeerAddr {
|
||||||
|
s.hostSource.refreshRing()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) handleNodeUp(eventIp net.IP, eventPort int, waitForBinary bool) {
|
||||||
|
if gocqlDebug {
|
||||||
|
Logger.Printf("gocql: Session.handleNodeUp: %s:%d\n", eventIp.String(), eventPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
ip, _ := s.cfg.translateAddressPort(eventIp, eventPort)
|
||||||
|
|
||||||
|
host := s.ring.getHost(ip)
|
||||||
|
if host == nil {
|
||||||
|
// TODO(zariel): avoid the need to translate twice in this
|
||||||
|
// case
|
||||||
|
s.handleNewNode(eventIp, eventPort, waitForBinary)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if t := host.Version().nodeUpDelay(); t > 0 && waitForBinary {
|
||||||
|
time.Sleep(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.addNewNode(host)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) handleNodeDown(ip net.IP, port int) {
|
||||||
|
if gocqlDebug {
|
||||||
|
Logger.Printf("gocql: Session.handleNodeDown: %s:%d\n", ip.String(), port)
|
||||||
|
}
|
||||||
|
|
||||||
|
host := s.ring.getHost(ip)
|
||||||
|
if host == nil {
|
||||||
|
host = &HostInfo{connectAddress: ip, port: port}
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
host.setState(NodeDown)
|
||||||
|
s.policy.HostDown(host)
|
||||||
|
s.pool.hostDown(ip)
|
||||||
|
}
|
@ -0,0 +1,57 @@
|
|||||||
|
package gocql
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// HostFilter interface is used when a host is discovered via server sent events.
|
||||||
|
type HostFilter interface {
|
||||||
|
// Called when a new host is discovered, returning true will cause the host
|
||||||
|
// to be added to the pools.
|
||||||
|
Accept(host *HostInfo) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// HostFilterFunc converts a func(host HostInfo) bool into a HostFilter
|
||||||
|
type HostFilterFunc func(host *HostInfo) bool
|
||||||
|
|
||||||
|
func (fn HostFilterFunc) Accept(host *HostInfo) bool {
|
||||||
|
return fn(host)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AcceptAllFilter will accept all hosts
|
||||||
|
func AcceptAllFilter() HostFilter {
|
||||||
|
return HostFilterFunc(func(host *HostInfo) bool {
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func DenyAllFilter() HostFilter {
|
||||||
|
return HostFilterFunc(func(host *HostInfo) bool {
|
||||||
|
return false
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// DataCentreHostFilter filters all hosts such that they are in the same data centre
|
||||||
|
// as the supplied data centre.
|
||||||
|
func DataCentreHostFilter(dataCentre string) HostFilter {
|
||||||
|
return HostFilterFunc(func(host *HostInfo) bool {
|
||||||
|
return host.DataCenter() == dataCentre
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// WhiteListHostFilter filters incoming hosts by checking that their address is
|
||||||
|
// in the initial hosts whitelist.
|
||||||
|
func WhiteListHostFilter(hosts ...string) HostFilter {
|
||||||
|
hostInfos, err := addrsToHosts(hosts, 9042)
|
||||||
|
if err != nil {
|
||||||
|
// dont want to panic here, but rather not break the API
|
||||||
|
panic(fmt.Errorf("unable to lookup host info from address: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
m := make(map[string]bool, len(hostInfos))
|
||||||
|
for _, host := range hostInfos {
|
||||||
|
m[host.ConnectAddress().String()] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return HostFilterFunc(func(host *HostInfo) bool {
|
||||||
|
return m[host.ConnectAddress().String()]
|
||||||
|
})
|
||||||
|
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,33 @@
|
|||||||
|
// +build gofuzz
|
||||||
|
|
||||||
|
package gocql
|
||||||
|
|
||||||
|
import "bytes"
|
||||||
|
|
||||||
|
func Fuzz(data []byte) int {
|
||||||
|
var bw bytes.Buffer
|
||||||
|
|
||||||
|
r := bytes.NewReader(data)
|
||||||
|
|
||||||
|
head, err := readHeader(r, make([]byte, 9))
|
||||||
|
if err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
framer := newFramer(r, &bw, nil, byte(head.version))
|
||||||
|
err = framer.readFrame(&head)
|
||||||
|
if err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
frame, err := framer.parseFrame()
|
||||||
|
if err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if frame != nil {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
return 2
|
||||||
|
}
|
@ -0,0 +1,13 @@
|
|||||||
|
module github.com/gocql/gocql
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 // indirect
|
||||||
|
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 // indirect
|
||||||
|
github.com/golang/snappy v0.0.0-20170215233205-553a64147049
|
||||||
|
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed
|
||||||
|
github.com/kr/pretty v0.1.0 // indirect
|
||||||
|
github.com/stretchr/testify v1.3.0 // indirect
|
||||||
|
gopkg.in/inf.v0 v0.9.1
|
||||||
|
)
|
||||||
|
|
||||||
|
go 1.13
|
@ -0,0 +1,22 @@
|
|||||||
|
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY=
|
||||||
|
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k=
|
||||||
|
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY=
|
||||||
|
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4=
|
||||||
|
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
|
||||||
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/golang/snappy v0.0.0-20170215233205-553a64147049 h1:K9KHZbXKpGydfDN0aZrsoHpLJlZsBrGMFWbgLDGnPZk=
|
||||||
|
github.com/golang/snappy v0.0.0-20170215233205-553a64147049/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||||
|
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8=
|
||||||
|
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4=
|
||||||
|
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
||||||
|
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||||
|
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||||
|
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||||
|
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
|
||||||
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
|
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
|
||||||
|
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
|
@ -0,0 +1,432 @@
|
|||||||
|
// Copyright (c) 2012 The gocql Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package gocql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"net"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gopkg.in/inf.v0"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RowData struct {
|
||||||
|
Columns []string
|
||||||
|
Values []interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func goType(t TypeInfo) reflect.Type {
|
||||||
|
switch t.Type() {
|
||||||
|
case TypeVarchar, TypeAscii, TypeInet, TypeText:
|
||||||
|
return reflect.TypeOf(*new(string))
|
||||||
|
case TypeBigInt, TypeCounter:
|
||||||
|
return reflect.TypeOf(*new(int64))
|
||||||
|
case TypeTime:
|
||||||
|
return reflect.TypeOf(*new(time.Duration))
|
||||||
|
case TypeTimestamp:
|
||||||
|
return reflect.TypeOf(*new(time.Time))
|
||||||
|
case TypeBlob:
|
||||||
|
return reflect.TypeOf(*new([]byte))
|
||||||
|
case TypeBoolean:
|
||||||
|
return reflect.TypeOf(*new(bool))
|
||||||
|
case TypeFloat:
|
||||||
|
return reflect.TypeOf(*new(float32))
|
||||||
|
case TypeDouble:
|
||||||
|
return reflect.TypeOf(*new(float64))
|
||||||
|
case TypeInt:
|
||||||
|
return reflect.TypeOf(*new(int))
|
||||||
|
case TypeSmallInt:
|
||||||
|
return reflect.TypeOf(*new(int16))
|
||||||
|
case TypeTinyInt:
|
||||||
|
return reflect.TypeOf(*new(int8))
|
||||||
|
case TypeDecimal:
|
||||||
|
return reflect.TypeOf(*new(*inf.Dec))
|
||||||
|
case TypeUUID, TypeTimeUUID:
|
||||||
|
return reflect.TypeOf(*new(UUID))
|
||||||
|
case TypeList, TypeSet:
|
||||||
|
return reflect.SliceOf(goType(t.(CollectionType).Elem))
|
||||||
|
case TypeMap:
|
||||||
|
return reflect.MapOf(goType(t.(CollectionType).Key), goType(t.(CollectionType).Elem))
|
||||||
|
case TypeVarint:
|
||||||
|
return reflect.TypeOf(*new(*big.Int))
|
||||||
|
case TypeTuple:
|
||||||
|
// what can we do here? all there is to do is to make a list of interface{}
|
||||||
|
tuple := t.(TupleTypeInfo)
|
||||||
|
return reflect.TypeOf(make([]interface{}, len(tuple.Elems)))
|
||||||
|
case TypeUDT:
|
||||||
|
return reflect.TypeOf(make(map[string]interface{}))
|
||||||
|
case TypeDate:
|
||||||
|
return reflect.TypeOf(*new(time.Time))
|
||||||
|
case TypeDuration:
|
||||||
|
return reflect.TypeOf(*new(Duration))
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func dereference(i interface{}) interface{} {
|
||||||
|
return reflect.Indirect(reflect.ValueOf(i)).Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCassandraBaseType(name string) Type {
|
||||||
|
switch name {
|
||||||
|
case "ascii":
|
||||||
|
return TypeAscii
|
||||||
|
case "bigint":
|
||||||
|
return TypeBigInt
|
||||||
|
case "blob":
|
||||||
|
return TypeBlob
|
||||||
|
case "boolean":
|
||||||
|
return TypeBoolean
|
||||||
|
case "counter":
|
||||||
|
return TypeCounter
|
||||||
|
case "decimal":
|
||||||
|
return TypeDecimal
|
||||||
|
case "double":
|
||||||
|
return TypeDouble
|
||||||
|
case "float":
|
||||||
|
return TypeFloat
|
||||||
|
case "int":
|
||||||
|
return TypeInt
|
||||||
|
case "tinyint":
|
||||||
|
return TypeTinyInt
|
||||||
|
case "time":
|
||||||
|
return TypeTime
|
||||||
|
case "timestamp":
|
||||||
|
return TypeTimestamp
|
||||||
|
case "uuid":
|
||||||
|
return TypeUUID
|
||||||
|
case "varchar":
|
||||||
|
return TypeVarchar
|
||||||
|
case "text":
|
||||||
|
return TypeText
|
||||||
|
case "varint":
|
||||||
|
return TypeVarint
|
||||||
|
case "timeuuid":
|
||||||
|
return TypeTimeUUID
|
||||||
|
case "inet":
|
||||||
|
return TypeInet
|
||||||
|
case "MapType":
|
||||||
|
return TypeMap
|
||||||
|
case "ListType":
|
||||||
|
return TypeList
|
||||||
|
case "SetType":
|
||||||
|
return TypeSet
|
||||||
|
case "TupleType":
|
||||||
|
return TypeTuple
|
||||||
|
default:
|
||||||
|
return TypeCustom
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCassandraType(name string) TypeInfo {
|
||||||
|
if strings.HasPrefix(name, "frozen<") {
|
||||||
|
return getCassandraType(strings.TrimPrefix(name[:len(name)-1], "frozen<"))
|
||||||
|
} else if strings.HasPrefix(name, "set<") {
|
||||||
|
return CollectionType{
|
||||||
|
NativeType: NativeType{typ: TypeSet},
|
||||||
|
Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "set<")),
|
||||||
|
}
|
||||||
|
} else if strings.HasPrefix(name, "list<") {
|
||||||
|
return CollectionType{
|
||||||
|
NativeType: NativeType{typ: TypeList},
|
||||||
|
Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "list<")),
|
||||||
|
}
|
||||||
|
} else if strings.HasPrefix(name, "map<") {
|
||||||
|
names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "map<"))
|
||||||
|
if len(names) != 2 {
|
||||||
|
Logger.Printf("Error parsing map type, it has %d subelements, expecting 2\n", len(names))
|
||||||
|
return NativeType{
|
||||||
|
typ: TypeCustom,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return CollectionType{
|
||||||
|
NativeType: NativeType{typ: TypeMap},
|
||||||
|
Key: getCassandraType(names[0]),
|
||||||
|
Elem: getCassandraType(names[1]),
|
||||||
|
}
|
||||||
|
} else if strings.HasPrefix(name, "tuple<") {
|
||||||
|
names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "tuple<"))
|
||||||
|
types := make([]TypeInfo, len(names))
|
||||||
|
|
||||||
|
for i, name := range names {
|
||||||
|
types[i] = getCassandraType(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return TupleTypeInfo{
|
||||||
|
NativeType: NativeType{typ: TypeTuple},
|
||||||
|
Elems: types,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return NativeType{
|
||||||
|
typ: getCassandraBaseType(name),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func splitCompositeTypes(name string) []string {
|
||||||
|
if !strings.Contains(name, "<") {
|
||||||
|
return strings.Split(name, ", ")
|
||||||
|
}
|
||||||
|
var parts []string
|
||||||
|
lessCount := 0
|
||||||
|
segment := ""
|
||||||
|
for _, char := range name {
|
||||||
|
if char == ',' && lessCount == 0 {
|
||||||
|
if segment != "" {
|
||||||
|
parts = append(parts, strings.TrimSpace(segment))
|
||||||
|
}
|
||||||
|
segment = ""
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
segment += string(char)
|
||||||
|
if char == '<' {
|
||||||
|
lessCount++
|
||||||
|
} else if char == '>' {
|
||||||
|
lessCount--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if segment != "" {
|
||||||
|
parts = append(parts, strings.TrimSpace(segment))
|
||||||
|
}
|
||||||
|
return parts
|
||||||
|
}
|
||||||
|
|
||||||
|
func apacheToCassandraType(t string) string {
|
||||||
|
t = strings.Replace(t, apacheCassandraTypePrefix, "", -1)
|
||||||
|
t = strings.Replace(t, "(", "<", -1)
|
||||||
|
t = strings.Replace(t, ")", ">", -1)
|
||||||
|
types := strings.FieldsFunc(t, func(r rune) bool {
|
||||||
|
return r == '<' || r == '>' || r == ','
|
||||||
|
})
|
||||||
|
for _, typ := range types {
|
||||||
|
t = strings.Replace(t, typ, getApacheCassandraType(typ).String(), -1)
|
||||||
|
}
|
||||||
|
// This is done so it exactly matches what Cassandra returns
|
||||||
|
return strings.Replace(t, ",", ", ", -1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getApacheCassandraType(class string) Type {
|
||||||
|
switch strings.TrimPrefix(class, apacheCassandraTypePrefix) {
|
||||||
|
case "AsciiType":
|
||||||
|
return TypeAscii
|
||||||
|
case "LongType":
|
||||||
|
return TypeBigInt
|
||||||
|
case "BytesType":
|
||||||
|
return TypeBlob
|
||||||
|
case "BooleanType":
|
||||||
|
return TypeBoolean
|
||||||
|
case "CounterColumnType":
|
||||||
|
return TypeCounter
|
||||||
|
case "DecimalType":
|
||||||
|
return TypeDecimal
|
||||||
|
case "DoubleType":
|
||||||
|
return TypeDouble
|
||||||
|
case "FloatType":
|
||||||
|
return TypeFloat
|
||||||
|
case "Int32Type":
|
||||||
|
return TypeInt
|
||||||
|
case "ShortType":
|
||||||
|
return TypeSmallInt
|
||||||
|
case "ByteType":
|
||||||
|
return TypeTinyInt
|
||||||
|
case "TimeType":
|
||||||
|
return TypeTime
|
||||||
|
case "DateType", "TimestampType":
|
||||||
|
return TypeTimestamp
|
||||||
|
case "UUIDType", "LexicalUUIDType":
|
||||||
|
return TypeUUID
|
||||||
|
case "UTF8Type":
|
||||||
|
return TypeVarchar
|
||||||
|
case "IntegerType":
|
||||||
|
return TypeVarint
|
||||||
|
case "TimeUUIDType":
|
||||||
|
return TypeTimeUUID
|
||||||
|
case "InetAddressType":
|
||||||
|
return TypeInet
|
||||||
|
case "MapType":
|
||||||
|
return TypeMap
|
||||||
|
case "ListType":
|
||||||
|
return TypeList
|
||||||
|
case "SetType":
|
||||||
|
return TypeSet
|
||||||
|
case "TupleType":
|
||||||
|
return TypeTuple
|
||||||
|
case "DurationType":
|
||||||
|
return TypeDuration
|
||||||
|
default:
|
||||||
|
return TypeCustom
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func typeCanBeNull(typ TypeInfo) bool {
|
||||||
|
switch typ.(type) {
|
||||||
|
case CollectionType, UDTTypeInfo, TupleTypeInfo:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RowData) rowMap(m map[string]interface{}) {
|
||||||
|
for i, column := range r.Columns {
|
||||||
|
val := dereference(r.Values[i])
|
||||||
|
if valVal := reflect.ValueOf(val); valVal.Kind() == reflect.Slice {
|
||||||
|
valCopy := reflect.MakeSlice(valVal.Type(), valVal.Len(), valVal.Cap())
|
||||||
|
reflect.Copy(valCopy, valVal)
|
||||||
|
m[column] = valCopy.Interface()
|
||||||
|
} else {
|
||||||
|
m[column] = val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TupeColumnName will return the column name of a tuple value in a column named
|
||||||
|
// c at index n. It should be used if a specific element within a tuple is needed
|
||||||
|
// to be extracted from a map returned from SliceMap or MapScan.
|
||||||
|
func TupleColumnName(c string, n int) string {
|
||||||
|
return fmt.Sprintf("%s[%d]", c, n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (iter *Iter) RowData() (RowData, error) {
|
||||||
|
if iter.err != nil {
|
||||||
|
return RowData{}, iter.err
|
||||||
|
}
|
||||||
|
|
||||||
|
columns := make([]string, 0, len(iter.Columns()))
|
||||||
|
values := make([]interface{}, 0, len(iter.Columns()))
|
||||||
|
|
||||||
|
for _, column := range iter.Columns() {
|
||||||
|
if c, ok := column.TypeInfo.(TupleTypeInfo); !ok {
|
||||||
|
val := column.TypeInfo.New()
|
||||||
|
columns = append(columns, column.Name)
|
||||||
|
values = append(values, val)
|
||||||
|
} else {
|
||||||
|
for i, elem := range c.Elems {
|
||||||
|
columns = append(columns, TupleColumnName(column.Name, i))
|
||||||
|
values = append(values, elem.New())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rowData := RowData{
|
||||||
|
Columns: columns,
|
||||||
|
Values: values,
|
||||||
|
}
|
||||||
|
|
||||||
|
return rowData, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(zariel): is it worth exporting this?
|
||||||
|
func (iter *Iter) rowMap() (map[string]interface{}, error) {
|
||||||
|
if iter.err != nil {
|
||||||
|
return nil, iter.err
|
||||||
|
}
|
||||||
|
|
||||||
|
rowData, _ := iter.RowData()
|
||||||
|
iter.Scan(rowData.Values...)
|
||||||
|
m := make(map[string]interface{}, len(rowData.Columns))
|
||||||
|
rowData.rowMap(m)
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SliceMap is a helper function to make the API easier to use
|
||||||
|
// returns the data from the query in the form of []map[string]interface{}
|
||||||
|
func (iter *Iter) SliceMap() ([]map[string]interface{}, error) {
|
||||||
|
if iter.err != nil {
|
||||||
|
return nil, iter.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not checking for the error because we just did
|
||||||
|
rowData, _ := iter.RowData()
|
||||||
|
dataToReturn := make([]map[string]interface{}, 0)
|
||||||
|
for iter.Scan(rowData.Values...) {
|
||||||
|
m := make(map[string]interface{}, len(rowData.Columns))
|
||||||
|
rowData.rowMap(m)
|
||||||
|
dataToReturn = append(dataToReturn, m)
|
||||||
|
}
|
||||||
|
if iter.err != nil {
|
||||||
|
return nil, iter.err
|
||||||
|
}
|
||||||
|
return dataToReturn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MapScan takes a map[string]interface{} and populates it with a row
|
||||||
|
// that is returned from cassandra.
|
||||||
|
//
|
||||||
|
// Each call to MapScan() must be called with a new map object.
|
||||||
|
// During the call to MapScan() any pointers in the existing map
|
||||||
|
// are replaced with non pointer types before the call returns
|
||||||
|
//
|
||||||
|
// iter := session.Query(`SELECT * FROM mytable`).Iter()
|
||||||
|
// for {
|
||||||
|
// // New map each iteration
|
||||||
|
// row = make(map[string]interface{})
|
||||||
|
// if !iter.MapScan(row) {
|
||||||
|
// break
|
||||||
|
// }
|
||||||
|
// // Do things with row
|
||||||
|
// if fullname, ok := row["fullname"]; ok {
|
||||||
|
// fmt.Printf("Full Name: %s\n", fullname)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// You can also pass pointers in the map before each call
|
||||||
|
//
|
||||||
|
// var fullName FullName // Implements gocql.Unmarshaler and gocql.Marshaler interfaces
|
||||||
|
// var address net.IP
|
||||||
|
// var age int
|
||||||
|
// iter := session.Query(`SELECT * FROM scan_map_table`).Iter()
|
||||||
|
// for {
|
||||||
|
// // New map each iteration
|
||||||
|
// row := map[string]interface{}{
|
||||||
|
// "fullname": &fullName,
|
||||||
|
// "age": &age,
|
||||||
|
// "address": &address,
|
||||||
|
// }
|
||||||
|
// if !iter.MapScan(row) {
|
||||||
|
// break
|
||||||
|
// }
|
||||||
|
// fmt.Printf("First: %s Age: %d Address: %q\n", fullName.FirstName, age, address)
|
||||||
|
// }
|
||||||
|
func (iter *Iter) MapScan(m map[string]interface{}) bool {
|
||||||
|
if iter.err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not checking for the error because we just did
|
||||||
|
rowData, _ := iter.RowData()
|
||||||
|
|
||||||
|
for i, col := range rowData.Columns {
|
||||||
|
if dest, ok := m[col]; ok {
|
||||||
|
rowData.Values[i] = dest
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if iter.Scan(rowData.Values...) {
|
||||||
|
rowData.rowMap(m)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyBytes(p []byte) []byte {
|
||||||
|
b := make([]byte, len(p))
|
||||||
|
copy(b, p)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
var failDNS = false
|
||||||
|
|
||||||
|
func LookupIP(host string) ([]net.IP, error) {
|
||||||
|
if failDNS {
|
||||||
|
return nil, &net.DNSError{}
|
||||||
|
}
|
||||||
|
return net.LookupIP(host)
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,715 @@
|
|||||||
|
package gocql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type nodeState int32
|
||||||
|
|
||||||
|
func (n nodeState) String() string {
|
||||||
|
if n == NodeUp {
|
||||||
|
return "UP"
|
||||||
|
} else if n == NodeDown {
|
||||||
|
return "DOWN"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("UNKNOWN_%d", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
NodeUp nodeState = iota
|
||||||
|
NodeDown
|
||||||
|
)
|
||||||
|
|
||||||
|
type cassVersion struct {
|
||||||
|
Major, Minor, Patch int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cassVersion) Set(v string) error {
|
||||||
|
if v == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.UnmarshalCQL(nil, []byte(v))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cassVersion) UnmarshalCQL(info TypeInfo, data []byte) error {
|
||||||
|
return c.unmarshal(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cassVersion) unmarshal(data []byte) error {
|
||||||
|
version := strings.TrimSuffix(string(data), "-SNAPSHOT")
|
||||||
|
version = strings.TrimPrefix(version, "v")
|
||||||
|
v := strings.Split(version, ".")
|
||||||
|
|
||||||
|
if len(v) < 2 {
|
||||||
|
return fmt.Errorf("invalid version string: %s", data)
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
c.Major, err = strconv.Atoi(v[0])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid major version %v: %v", v[0], err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Minor, err = strconv.Atoi(v[1])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid minor version %v: %v", v[1], err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(v) > 2 {
|
||||||
|
c.Patch, err = strconv.Atoi(v[2])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid patch version %v: %v", v[2], err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c cassVersion) Before(major, minor, patch int) bool {
|
||||||
|
// We're comparing us (cassVersion) with the provided version (major, minor, patch)
|
||||||
|
// We return true if our version is lower (comes before) than the provided one.
|
||||||
|
if c.Major < major {
|
||||||
|
return true
|
||||||
|
} else if c.Major == major {
|
||||||
|
if c.Minor < minor {
|
||||||
|
return true
|
||||||
|
} else if c.Minor == minor && c.Patch < patch {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c cassVersion) AtLeast(major, minor, patch int) bool {
|
||||||
|
return !c.Before(major, minor, patch)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c cassVersion) String() string {
|
||||||
|
return fmt.Sprintf("v%d.%d.%d", c.Major, c.Minor, c.Patch)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c cassVersion) nodeUpDelay() time.Duration {
|
||||||
|
if c.Major >= 2 && c.Minor >= 2 {
|
||||||
|
// CASSANDRA-8236
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return 10 * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
type HostInfo struct {
|
||||||
|
// TODO(zariel): reduce locking maybe, not all values will change, but to ensure
|
||||||
|
// that we are thread safe use a mutex to access all fields.
|
||||||
|
mu sync.RWMutex
|
||||||
|
hostname string
|
||||||
|
peer net.IP
|
||||||
|
broadcastAddress net.IP
|
||||||
|
listenAddress net.IP
|
||||||
|
rpcAddress net.IP
|
||||||
|
preferredIP net.IP
|
||||||
|
connectAddress net.IP
|
||||||
|
port int
|
||||||
|
dataCenter string
|
||||||
|
rack string
|
||||||
|
hostId string
|
||||||
|
workload string
|
||||||
|
graph bool
|
||||||
|
dseVersion string
|
||||||
|
partitioner string
|
||||||
|
clusterName string
|
||||||
|
version cassVersion
|
||||||
|
state nodeState
|
||||||
|
schemaVersion string
|
||||||
|
tokens []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) Equal(host *HostInfo) bool {
|
||||||
|
if h == host {
|
||||||
|
// prevent rlock reentry
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return h.ConnectAddress().Equal(host.ConnectAddress())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) Peer() net.IP {
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
return h.peer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) setPeer(peer net.IP) *HostInfo {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
h.peer = peer
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) invalidConnectAddr() bool {
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
addr, _ := h.connectAddressLocked()
|
||||||
|
return !validIpAddr(addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func validIpAddr(addr net.IP) bool {
|
||||||
|
return addr != nil && !addr.IsUnspecified()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) connectAddressLocked() (net.IP, string) {
|
||||||
|
if validIpAddr(h.connectAddress) {
|
||||||
|
return h.connectAddress, "connect_address"
|
||||||
|
} else if validIpAddr(h.rpcAddress) {
|
||||||
|
return h.rpcAddress, "rpc_adress"
|
||||||
|
} else if validIpAddr(h.preferredIP) {
|
||||||
|
// where does perferred_ip get set?
|
||||||
|
return h.preferredIP, "preferred_ip"
|
||||||
|
} else if validIpAddr(h.broadcastAddress) {
|
||||||
|
return h.broadcastAddress, "broadcast_address"
|
||||||
|
} else if validIpAddr(h.peer) {
|
||||||
|
return h.peer, "peer"
|
||||||
|
}
|
||||||
|
return net.IPv4zero, "invalid"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the address that should be used to connect to the host.
|
||||||
|
// If you wish to override this, use an AddressTranslator or
|
||||||
|
// use a HostFilter to SetConnectAddress()
|
||||||
|
func (h *HostInfo) ConnectAddress() net.IP {
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
|
||||||
|
if addr, _ := h.connectAddressLocked(); validIpAddr(addr) {
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
panic(fmt.Sprintf("no valid connect address for host: %v. Is your cluster configured correctly?", h))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) SetConnectAddress(address net.IP) *HostInfo {
|
||||||
|
// TODO(zariel): should this not be exported?
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
h.connectAddress = address
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) BroadcastAddress() net.IP {
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
return h.broadcastAddress
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) ListenAddress() net.IP {
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
return h.listenAddress
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) RPCAddress() net.IP {
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
return h.rpcAddress
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) PreferredIP() net.IP {
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
return h.preferredIP
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) DataCenter() string {
|
||||||
|
h.mu.RLock()
|
||||||
|
dc := h.dataCenter
|
||||||
|
h.mu.RUnlock()
|
||||||
|
return dc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) setDataCenter(dataCenter string) *HostInfo {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
h.dataCenter = dataCenter
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) Rack() string {
|
||||||
|
h.mu.RLock()
|
||||||
|
rack := h.rack
|
||||||
|
h.mu.RUnlock()
|
||||||
|
return rack
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) setRack(rack string) *HostInfo {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
h.rack = rack
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) HostID() string {
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
return h.hostId
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) setHostID(hostID string) *HostInfo {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
h.hostId = hostID
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) WorkLoad() string {
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
return h.workload
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) Graph() bool {
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
return h.graph
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) DSEVersion() string {
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
return h.dseVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) Partitioner() string {
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
return h.partitioner
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) ClusterName() string {
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
return h.clusterName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) Version() cassVersion {
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
return h.version
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) setVersion(major, minor, patch int) *HostInfo {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
h.version = cassVersion{major, minor, patch}
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) State() nodeState {
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
return h.state
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) setState(state nodeState) *HostInfo {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
h.state = state
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) Tokens() []string {
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
return h.tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) setTokens(tokens []string) *HostInfo {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
h.tokens = tokens
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) Port() int {
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
return h.port
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) setPort(port int) *HostInfo {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
h.port = port
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) update(from *HostInfo) {
|
||||||
|
if h == from {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
|
||||||
|
from.mu.RLock()
|
||||||
|
defer from.mu.RUnlock()
|
||||||
|
|
||||||
|
// autogenerated do not update
|
||||||
|
if h.peer == nil {
|
||||||
|
h.peer = from.peer
|
||||||
|
}
|
||||||
|
if h.broadcastAddress == nil {
|
||||||
|
h.broadcastAddress = from.broadcastAddress
|
||||||
|
}
|
||||||
|
if h.listenAddress == nil {
|
||||||
|
h.listenAddress = from.listenAddress
|
||||||
|
}
|
||||||
|
if h.rpcAddress == nil {
|
||||||
|
h.rpcAddress = from.rpcAddress
|
||||||
|
}
|
||||||
|
if h.preferredIP == nil {
|
||||||
|
h.preferredIP = from.preferredIP
|
||||||
|
}
|
||||||
|
if h.connectAddress == nil {
|
||||||
|
h.connectAddress = from.connectAddress
|
||||||
|
}
|
||||||
|
if h.port == 0 {
|
||||||
|
h.port = from.port
|
||||||
|
}
|
||||||
|
if h.dataCenter == "" {
|
||||||
|
h.dataCenter = from.dataCenter
|
||||||
|
}
|
||||||
|
if h.rack == "" {
|
||||||
|
h.rack = from.rack
|
||||||
|
}
|
||||||
|
if h.hostId == "" {
|
||||||
|
h.hostId = from.hostId
|
||||||
|
}
|
||||||
|
if h.workload == "" {
|
||||||
|
h.workload = from.workload
|
||||||
|
}
|
||||||
|
if h.dseVersion == "" {
|
||||||
|
h.dseVersion = from.dseVersion
|
||||||
|
}
|
||||||
|
if h.partitioner == "" {
|
||||||
|
h.partitioner = from.partitioner
|
||||||
|
}
|
||||||
|
if h.clusterName == "" {
|
||||||
|
h.clusterName = from.clusterName
|
||||||
|
}
|
||||||
|
if h.version == (cassVersion{}) {
|
||||||
|
h.version = from.version
|
||||||
|
}
|
||||||
|
if h.tokens == nil {
|
||||||
|
h.tokens = from.tokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) IsUp() bool {
|
||||||
|
return h != nil && h.State() == NodeUp
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) HostnameAndPort() string {
|
||||||
|
if h.hostname == "" {
|
||||||
|
h.hostname = h.ConnectAddress().String()
|
||||||
|
}
|
||||||
|
return net.JoinHostPort(h.hostname, strconv.Itoa(h.port))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HostInfo) String() string {
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
|
||||||
|
connectAddr, source := h.connectAddressLocked()
|
||||||
|
return fmt.Sprintf("[HostInfo hostname=%q connectAddress=%q peer=%q rpc_address=%q broadcast_address=%q "+
|
||||||
|
"preferred_ip=%q connect_addr=%q connect_addr_source=%q "+
|
||||||
|
"port=%d data_centre=%q rack=%q host_id=%q version=%q state=%s num_tokens=%d]",
|
||||||
|
h.hostname, h.connectAddress, h.peer, h.rpcAddress, h.broadcastAddress, h.preferredIP,
|
||||||
|
connectAddr, source,
|
||||||
|
h.port, h.dataCenter, h.rack, h.hostId, h.version, h.state, len(h.tokens))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Polls system.peers at a specific interval to find new hosts
|
||||||
|
type ringDescriber struct {
|
||||||
|
session *Session
|
||||||
|
mu sync.Mutex
|
||||||
|
prevHosts []*HostInfo
|
||||||
|
prevPartitioner string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns true if we are using system_schema.keyspaces instead of system.schema_keyspaces
|
||||||
|
func checkSystemSchema(control *controlConn) (bool, error) {
|
||||||
|
iter := control.query("SELECT * FROM system_schema.keyspaces")
|
||||||
|
if err := iter.err; err != nil {
|
||||||
|
if errf, ok := err.(*errorFrame); ok {
|
||||||
|
if errf.code == errSyntax {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Given a map that represents a row from either system.local or system.peers
|
||||||
|
// return as much information as we can in *HostInfo
|
||||||
|
func (s *Session) hostInfoFromMap(row map[string]interface{}, host *HostInfo) (*HostInfo, error) {
|
||||||
|
const assertErrorMsg = "Assertion failed for %s"
|
||||||
|
var ok bool
|
||||||
|
|
||||||
|
// Default to our connected port if the cluster doesn't have port information
|
||||||
|
for key, value := range row {
|
||||||
|
switch key {
|
||||||
|
case "data_center":
|
||||||
|
host.dataCenter, ok = value.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf(assertErrorMsg, "data_center")
|
||||||
|
}
|
||||||
|
case "rack":
|
||||||
|
host.rack, ok = value.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf(assertErrorMsg, "rack")
|
||||||
|
}
|
||||||
|
case "host_id":
|
||||||
|
hostId, ok := value.(UUID)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf(assertErrorMsg, "host_id")
|
||||||
|
}
|
||||||
|
host.hostId = hostId.String()
|
||||||
|
case "release_version":
|
||||||
|
version, ok := value.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf(assertErrorMsg, "release_version")
|
||||||
|
}
|
||||||
|
host.version.Set(version)
|
||||||
|
case "peer":
|
||||||
|
ip, ok := value.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf(assertErrorMsg, "peer")
|
||||||
|
}
|
||||||
|
host.peer = net.ParseIP(ip)
|
||||||
|
case "cluster_name":
|
||||||
|
host.clusterName, ok = value.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf(assertErrorMsg, "cluster_name")
|
||||||
|
}
|
||||||
|
case "partitioner":
|
||||||
|
host.partitioner, ok = value.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf(assertErrorMsg, "partitioner")
|
||||||
|
}
|
||||||
|
case "broadcast_address":
|
||||||
|
ip, ok := value.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf(assertErrorMsg, "broadcast_address")
|
||||||
|
}
|
||||||
|
host.broadcastAddress = net.ParseIP(ip)
|
||||||
|
case "preferred_ip":
|
||||||
|
ip, ok := value.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf(assertErrorMsg, "preferred_ip")
|
||||||
|
}
|
||||||
|
host.preferredIP = net.ParseIP(ip)
|
||||||
|
case "rpc_address":
|
||||||
|
ip, ok := value.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf(assertErrorMsg, "rpc_address")
|
||||||
|
}
|
||||||
|
host.rpcAddress = net.ParseIP(ip)
|
||||||
|
case "listen_address":
|
||||||
|
ip, ok := value.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf(assertErrorMsg, "listen_address")
|
||||||
|
}
|
||||||
|
host.listenAddress = net.ParseIP(ip)
|
||||||
|
case "workload":
|
||||||
|
host.workload, ok = value.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf(assertErrorMsg, "workload")
|
||||||
|
}
|
||||||
|
case "graph":
|
||||||
|
host.graph, ok = value.(bool)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf(assertErrorMsg, "graph")
|
||||||
|
}
|
||||||
|
case "tokens":
|
||||||
|
host.tokens, ok = value.([]string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf(assertErrorMsg, "tokens")
|
||||||
|
}
|
||||||
|
case "dse_version":
|
||||||
|
host.dseVersion, ok = value.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf(assertErrorMsg, "dse_version")
|
||||||
|
}
|
||||||
|
case "schema_version":
|
||||||
|
schemaVersion, ok := value.(UUID)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf(assertErrorMsg, "schema_version")
|
||||||
|
}
|
||||||
|
host.schemaVersion = schemaVersion.String()
|
||||||
|
}
|
||||||
|
// TODO(thrawn01): Add 'port'? once CASSANDRA-7544 is complete
|
||||||
|
// Not sure what the port field will be called until the JIRA issue is complete
|
||||||
|
}
|
||||||
|
|
||||||
|
ip, port := s.cfg.translateAddressPort(host.ConnectAddress(), host.port)
|
||||||
|
host.connectAddress = ip
|
||||||
|
host.port = port
|
||||||
|
|
||||||
|
return host, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ask the control node for host info on all it's known peers
|
||||||
|
func (r *ringDescriber) getClusterPeerInfo() ([]*HostInfo, error) {
|
||||||
|
var hosts []*HostInfo
|
||||||
|
iter := r.session.control.withConnHost(func(ch *connHost) *Iter {
|
||||||
|
hosts = append(hosts, ch.host)
|
||||||
|
return ch.conn.query(context.TODO(), "SELECT * FROM system.peers")
|
||||||
|
})
|
||||||
|
|
||||||
|
if iter == nil {
|
||||||
|
return nil, errNoControl
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := iter.SliceMap()
|
||||||
|
if err != nil {
|
||||||
|
// TODO(zariel): make typed error
|
||||||
|
return nil, fmt.Errorf("unable to fetch peer host info: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, row := range rows {
|
||||||
|
// extract all available info about the peer
|
||||||
|
host, err := r.session.hostInfoFromMap(row, &HostInfo{port: r.session.cfg.Port})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else if !isValidPeer(host) {
|
||||||
|
// If it's not a valid peer
|
||||||
|
Logger.Printf("Found invalid peer '%s' "+
|
||||||
|
"Likely due to a gossip or snitch issue, this host will be ignored", host)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
hosts = append(hosts, host)
|
||||||
|
}
|
||||||
|
|
||||||
|
return hosts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return true if the host is a valid peer
|
||||||
|
func isValidPeer(host *HostInfo) bool {
|
||||||
|
return !(len(host.RPCAddress()) == 0 ||
|
||||||
|
host.hostId == "" ||
|
||||||
|
host.dataCenter == "" ||
|
||||||
|
host.rack == "" ||
|
||||||
|
len(host.tokens) == 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return a list of hosts the cluster knows about
|
||||||
|
func (r *ringDescriber) GetHosts() ([]*HostInfo, string, error) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
hosts, err := r.getClusterPeerInfo()
|
||||||
|
if err != nil {
|
||||||
|
return r.prevHosts, r.prevPartitioner, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var partitioner string
|
||||||
|
if len(hosts) > 0 {
|
||||||
|
partitioner = hosts[0].Partitioner()
|
||||||
|
}
|
||||||
|
|
||||||
|
return hosts, partitioner, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Given an ip/port return HostInfo for the specified ip/port
|
||||||
|
func (r *ringDescriber) getHostInfo(ip net.IP, port int) (*HostInfo, error) {
|
||||||
|
var host *HostInfo
|
||||||
|
iter := r.session.control.withConnHost(func(ch *connHost) *Iter {
|
||||||
|
if ch.host.ConnectAddress().Equal(ip) {
|
||||||
|
host = ch.host
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return ch.conn.query(context.TODO(), "SELECT * FROM system.peers")
|
||||||
|
})
|
||||||
|
|
||||||
|
if iter != nil {
|
||||||
|
rows, err := iter.SliceMap()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, row := range rows {
|
||||||
|
h, err := r.session.hostInfoFromMap(row, &HostInfo{port: port})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.ConnectAddress().Equal(ip) {
|
||||||
|
host = h
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if host == nil {
|
||||||
|
return nil, errors.New("host not found in peers table")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if host == nil {
|
||||||
|
return nil, errors.New("unable to fetch host info: invalid control connection")
|
||||||
|
} else if host.invalidConnectAddr() {
|
||||||
|
return nil, fmt.Errorf("host ConnectAddress invalid ip=%v: %v", ip, host)
|
||||||
|
}
|
||||||
|
|
||||||
|
return host, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ringDescriber) refreshRing() error {
|
||||||
|
// if we have 0 hosts this will return the previous list of hosts to
|
||||||
|
// attempt to reconnect to the cluster otherwise we would never find
|
||||||
|
// downed hosts again, could possibly have an optimisation to only
|
||||||
|
// try to add new hosts if GetHosts didnt error and the hosts didnt change.
|
||||||
|
hosts, partitioner, err := r.GetHosts()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
prevHosts := r.session.ring.currentHosts()
|
||||||
|
|
||||||
|
// TODO: move this to session
|
||||||
|
for _, h := range hosts {
|
||||||
|
if filter := r.session.cfg.HostFilter; filter != nil && !filter.Accept(h) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if host, ok := r.session.ring.addHostIfMissing(h); !ok {
|
||||||
|
r.session.pool.addHost(h)
|
||||||
|
r.session.policy.AddHost(h)
|
||||||
|
} else {
|
||||||
|
host.update(h)
|
||||||
|
}
|
||||||
|
delete(prevHosts, h.ConnectAddress().String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(zariel): it may be worth having a mutex covering the overall ring state
|
||||||
|
// in a session so that everything sees a consistent state. Becuase as is today
|
||||||
|
// events can come in and due to ordering an UP host could be removed from the cluster
|
||||||
|
for _, host := range prevHosts {
|
||||||
|
r.session.removeHost(host)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.session.metadata.setPartitioner(partitioner)
|
||||||
|
r.session.policy.SetPartitioner(partitioner)
|
||||||
|
return nil
|
||||||
|
}
|
@ -0,0 +1,45 @@
|
|||||||
|
// +build genhostinfo
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/gocql/gocql"
|
||||||
|
)
|
||||||
|
|
||||||
|
func gen(clause, field string) {
|
||||||
|
fmt.Printf("if h.%s == %s {\n", field, clause)
|
||||||
|
fmt.Printf("\th.%s = from.%s\n", field, field)
|
||||||
|
fmt.Println("}")
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
t := reflect.ValueOf(&gocql.HostInfo{}).Elem().Type()
|
||||||
|
mu := reflect.TypeOf(sync.RWMutex{})
|
||||||
|
|
||||||
|
for i := 0; i < t.NumField(); i++ {
|
||||||
|
f := t.Field(i)
|
||||||
|
if f.Type == mu {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch f.Type.Kind() {
|
||||||
|
case reflect.Slice:
|
||||||
|
gen("nil", f.Name)
|
||||||
|
case reflect.String:
|
||||||
|
gen(`""`, f.Name)
|
||||||
|
case reflect.Int:
|
||||||
|
gen("0", f.Name)
|
||||||
|
case reflect.Struct:
|
||||||
|
gen("("+f.Type.Name()+"{})", f.Name)
|
||||||
|
case reflect.Bool, reflect.Int32:
|
||||||
|
continue
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("unknown field: %s", f))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,16 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
# This is not supposed to be an error-prone script; just a convenience.
|
||||||
|
|
||||||
|
# Install CCM
|
||||||
|
pip install --user cql PyYAML six
|
||||||
|
git clone https://github.com/pcmanus/ccm.git
|
||||||
|
pushd ccm
|
||||||
|
./setup.py install --user
|
||||||
|
popd
|
||||||
|
|
||||||
|
if [ "$1" != "gocql/gocql" ]; then
|
||||||
|
USER=$(echo $1 | cut -f1 -d'/')
|
||||||
|
cd ../..
|
||||||
|
mv ${USER} gocql
|
||||||
|
fi
|
@ -0,0 +1,95 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -eux
|
||||||
|
|
||||||
|
function run_tests() {
|
||||||
|
local clusterSize=3
|
||||||
|
local version=$1
|
||||||
|
local auth=$2
|
||||||
|
|
||||||
|
if [ "$auth" = true ]; then
|
||||||
|
clusterSize=1
|
||||||
|
fi
|
||||||
|
|
||||||
|
local keypath="$(pwd)/testdata/pki"
|
||||||
|
|
||||||
|
local conf=(
|
||||||
|
"client_encryption_options.enabled: true"
|
||||||
|
"client_encryption_options.keystore: $keypath/.keystore"
|
||||||
|
"client_encryption_options.keystore_password: cassandra"
|
||||||
|
"client_encryption_options.require_client_auth: true"
|
||||||
|
"client_encryption_options.truststore: $keypath/.truststore"
|
||||||
|
"client_encryption_options.truststore_password: cassandra"
|
||||||
|
"concurrent_reads: 2"
|
||||||
|
"concurrent_writes: 2"
|
||||||
|
"rpc_server_type: sync"
|
||||||
|
"rpc_min_threads: 2"
|
||||||
|
"rpc_max_threads: 2"
|
||||||
|
"write_request_timeout_in_ms: 5000"
|
||||||
|
"read_request_timeout_in_ms: 5000"
|
||||||
|
)
|
||||||
|
|
||||||
|
ccm remove test || true
|
||||||
|
|
||||||
|
ccm create test -v $version -n $clusterSize -d --vnodes --jvm_arg="-Xmx256m -XX:NewSize=100m"
|
||||||
|
ccm updateconf "${conf[@]}"
|
||||||
|
|
||||||
|
if [ "$auth" = true ]
|
||||||
|
then
|
||||||
|
ccm updateconf 'authenticator: PasswordAuthenticator' 'authorizer: CassandraAuthorizer'
|
||||||
|
rm -rf $HOME/.ccm/test/node1/data/system_auth
|
||||||
|
fi
|
||||||
|
|
||||||
|
local proto=2
|
||||||
|
if [[ $version == 1.2.* ]]; then
|
||||||
|
proto=1
|
||||||
|
elif [[ $version == 2.0.* ]]; then
|
||||||
|
proto=2
|
||||||
|
elif [[ $version == 2.1.* ]]; then
|
||||||
|
proto=3
|
||||||
|
elif [[ $version == 2.2.* || $version == 3.0.* ]]; then
|
||||||
|
proto=4
|
||||||
|
ccm updateconf 'enable_user_defined_functions: true'
|
||||||
|
export JVM_EXTRA_OPTS=" -Dcassandra.test.fail_writes_ks=test -Dcassandra.custom_query_handler_class=org.apache.cassandra.cql3.CustomPayloadMirroringQueryHandler"
|
||||||
|
elif [[ $version == 3.*.* ]]; then
|
||||||
|
proto=5
|
||||||
|
ccm updateconf 'enable_user_defined_functions: true'
|
||||||
|
export JVM_EXTRA_OPTS=" -Dcassandra.test.fail_writes_ks=test -Dcassandra.custom_query_handler_class=org.apache.cassandra.cql3.CustomPayloadMirroringQueryHandler"
|
||||||
|
fi
|
||||||
|
|
||||||
|
sleep 1s
|
||||||
|
|
||||||
|
ccm list
|
||||||
|
ccm start --wait-for-binary-proto
|
||||||
|
ccm status
|
||||||
|
ccm node1 nodetool status
|
||||||
|
|
||||||
|
local args="-gocql.timeout=60s -runssl -proto=$proto -rf=3 -clusterSize=$clusterSize -autowait=2000ms -compressor=snappy -gocql.cversion=$version -cluster=$(ccm liveset) ./..."
|
||||||
|
|
||||||
|
go test -v -tags unit -race
|
||||||
|
|
||||||
|
if [ "$auth" = true ]
|
||||||
|
then
|
||||||
|
sleep 30s
|
||||||
|
go test -run=TestAuthentication -tags "integration gocql_debug" -timeout=15s -runauth $args
|
||||||
|
else
|
||||||
|
sleep 1s
|
||||||
|
go test -tags "cassandra gocql_debug" -timeout=5m -race $args
|
||||||
|
|
||||||
|
ccm clear
|
||||||
|
ccm start --wait-for-binary-proto
|
||||||
|
sleep 1s
|
||||||
|
|
||||||
|
go test -tags "integration gocql_debug" -timeout=5m -race $args
|
||||||
|
|
||||||
|
ccm clear
|
||||||
|
ccm start --wait-for-binary-proto
|
||||||
|
sleep 1s
|
||||||
|
|
||||||
|
go test -tags "ccm gocql_debug" -timeout=5m -race $args
|
||||||
|
fi
|
||||||
|
|
||||||
|
ccm remove
|
||||||
|
}
|
||||||
|
|
||||||
|
run_tests $1 $2
|
@ -0,0 +1,127 @@
|
|||||||
|
/*
|
||||||
|
Copyright 2015 To gocql authors
|
||||||
|
Copyright 2013 Google Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Package lru implements an LRU cache.
|
||||||
|
package lru
|
||||||
|
|
||||||
|
import "container/list"
|
||||||
|
|
||||||
|
// Cache is an LRU cache. It is not safe for concurrent access.
|
||||||
|
//
|
||||||
|
// This cache has been forked from github.com/golang/groupcache/lru, but
|
||||||
|
// specialized with string keys to avoid the allocations caused by wrapping them
|
||||||
|
// in interface{}.
|
||||||
|
type Cache struct {
|
||||||
|
// MaxEntries is the maximum number of cache entries before
|
||||||
|
// an item is evicted. Zero means no limit.
|
||||||
|
MaxEntries int
|
||||||
|
|
||||||
|
// OnEvicted optionally specifies a callback function to be
|
||||||
|
// executed when an entry is purged from the cache.
|
||||||
|
OnEvicted func(key string, value interface{})
|
||||||
|
|
||||||
|
ll *list.List
|
||||||
|
cache map[string]*list.Element
|
||||||
|
}
|
||||||
|
|
||||||
|
type entry struct {
|
||||||
|
key string
|
||||||
|
value interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new Cache.
|
||||||
|
// If maxEntries is zero, the cache has no limit and it's assumed
|
||||||
|
// that eviction is done by the caller.
|
||||||
|
func New(maxEntries int) *Cache {
|
||||||
|
return &Cache{
|
||||||
|
MaxEntries: maxEntries,
|
||||||
|
ll: list.New(),
|
||||||
|
cache: make(map[string]*list.Element),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add adds a value to the cache.
|
||||||
|
func (c *Cache) Add(key string, value interface{}) {
|
||||||
|
if c.cache == nil {
|
||||||
|
c.cache = make(map[string]*list.Element)
|
||||||
|
c.ll = list.New()
|
||||||
|
}
|
||||||
|
if ee, ok := c.cache[key]; ok {
|
||||||
|
c.ll.MoveToFront(ee)
|
||||||
|
ee.Value.(*entry).value = value
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ele := c.ll.PushFront(&entry{key, value})
|
||||||
|
c.cache[key] = ele
|
||||||
|
if c.MaxEntries != 0 && c.ll.Len() > c.MaxEntries {
|
||||||
|
c.RemoveOldest()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get looks up a key's value from the cache.
|
||||||
|
func (c *Cache) Get(key string) (value interface{}, ok bool) {
|
||||||
|
if c.cache == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if ele, hit := c.cache[key]; hit {
|
||||||
|
c.ll.MoveToFront(ele)
|
||||||
|
return ele.Value.(*entry).value, true
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove removes the provided key from the cache.
|
||||||
|
func (c *Cache) Remove(key string) bool {
|
||||||
|
if c.cache == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if ele, hit := c.cache[key]; hit {
|
||||||
|
c.removeElement(ele)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveOldest removes the oldest item from the cache.
|
||||||
|
func (c *Cache) RemoveOldest() {
|
||||||
|
if c.cache == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ele := c.ll.Back()
|
||||||
|
if ele != nil {
|
||||||
|
c.removeElement(ele)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeElement unlinks e from the recency list, deletes its map
// entry, and fires the eviction callback (if any) with the evicted
// key/value pair.
func (c *Cache) removeElement(e *list.Element) {
	c.ll.Remove(e)
	kv := e.Value.(*entry)
	delete(c.cache, kv.key)
	if c.OnEvicted != nil {
		c.OnEvicted(kv.key, kv.value)
	}
}
|
||||||
|
|
||||||
|
// Len returns the number of items in the cache.
|
||||||
|
func (c *Cache) Len() int {
|
||||||
|
if c.cache == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return c.ll.Len()
|
||||||
|
}
|
@ -0,0 +1,135 @@
|
|||||||
|
package murmur

// Murmur3 (x64, 128-bit variant) round and finalization constants,
// written as signed 64-bit values to match the Cassandra (Java)
// implementation's arithmetic.
const (
	c1    int64 = -8663945395140668459 // 0x87c37b91114253d5
	c2    int64 = 5545529020109919103  // 0x4cf5ad432745937f
	fmix1 int64 = -49064778989728563   // 0xff51afd7ed558ccd
	fmix2 int64 = -4265267296055464877 // 0xc4ceb9fe1a85ec53
)
|
||||||
|
|
||||||
|
func fmix(n int64) int64 {
|
||||||
|
// cast to unsigned for logical right bitshift (to match C* MM3 implementation)
|
||||||
|
n ^= int64(uint64(n) >> 33)
|
||||||
|
n *= fmix1
|
||||||
|
n ^= int64(uint64(n) >> 33)
|
||||||
|
n *= fmix2
|
||||||
|
n ^= int64(uint64(n) >> 33)
|
||||||
|
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
// block widens a byte to a sign-extended 64-bit value, matching
// Java's signed byte-to-long conversion.
func block(p byte) int64 {
	signed := int8(p)
	return int64(signed)
}
|
||||||
|
|
||||||
|
// rotl rotates x left by r bits. The cast to unsigned gives the
// logical right shift needed to bring the wrapped-around low bits in
// (matching the C* MM3 implementation).
func rotl(x int64, r uint8) int64 {
	hi := x << r
	lo := (int64)(uint64(x) >> (64 - r))
	return hi | lo
}
|
||||||
|
|
||||||
|
// Murmur3H1 returns the upper 64 bits (h1) of the 128-bit Murmur3
// (x64 variant) hash of data. h2 is computed but discarded, matching
// Cassandra's use of only h1 for token hashing.
func Murmur3H1(data []byte) int64 {
	length := len(data)

	var h1, h2, k1, k2 int64

	// body: consume the input 16 bytes at a time as two
	// little-endian 64-bit blocks (see getBlock).
	nBlocks := length / 16
	for i := 0; i < nBlocks; i++ {
		k1, k2 = getBlock(data, i)

		k1 *= c1
		k1 = rotl(k1, 31)
		k1 *= c2
		h1 ^= k1

		h1 = rotl(h1, 27)
		h1 += h2
		h1 = h1*5 + 0x52dce729

		k2 *= c2
		k2 = rotl(k2, 33)
		k2 *= c1
		h2 ^= k2

		h2 = rotl(h2, 31)
		h2 += h1
		h2 = h2*5 + 0x38495ab5
	}

	// tail: fold in the remaining 0-15 bytes. The cases fall
	// through so every remaining byte lands at its shift position.
	tail := data[nBlocks*16:]
	k1 = 0
	k2 = 0
	switch length & 15 {
	case 15:
		k2 ^= block(tail[14]) << 48
		fallthrough
	case 14:
		k2 ^= block(tail[13]) << 40
		fallthrough
	case 13:
		k2 ^= block(tail[12]) << 32
		fallthrough
	case 12:
		k2 ^= block(tail[11]) << 24
		fallthrough
	case 11:
		k2 ^= block(tail[10]) << 16
		fallthrough
	case 10:
		k2 ^= block(tail[9]) << 8
		fallthrough
	case 9:
		k2 ^= block(tail[8])

		// k2 only gets mixed when the tail spills past 8 bytes.
		k2 *= c2
		k2 = rotl(k2, 33)
		k2 *= c1
		h2 ^= k2

		fallthrough
	case 8:
		k1 ^= block(tail[7]) << 56
		fallthrough
	case 7:
		k1 ^= block(tail[6]) << 48
		fallthrough
	case 6:
		k1 ^= block(tail[5]) << 40
		fallthrough
	case 5:
		k1 ^= block(tail[4]) << 32
		fallthrough
	case 4:
		k1 ^= block(tail[3]) << 24
		fallthrough
	case 3:
		k1 ^= block(tail[2]) << 16
		fallthrough
	case 2:
		k1 ^= block(tail[1]) << 8
		fallthrough
	case 1:
		k1 ^= block(tail[0])

		k1 *= c1
		k1 = rotl(k1, 31)
		k1 *= c2
		h1 ^= k1
	}

	// finalization: fold in the length and avalanche both halves.
	h1 ^= int64(length)
	h2 ^= int64(length)

	h1 += h2
	h2 += h1

	h1 = fmix(h1)
	h2 = fmix(h2)

	h1 += h2
	// the following is extraneous since h2 is discarded
	// h2 += h1

	return h1
}
|
@ -0,0 +1,11 @@
|
|||||||
|
// +build appengine

package murmur

import "encoding/binary"

// getBlock returns the n-th pair of little-endian 64-bit blocks from
// data (each block pair covers 16 bytes). This copy-based variant is
// used on App Engine, where package unsafe is unavailable.
func getBlock(data []byte, n int) (int64, int64) {
	// binary.ByteOrder only decodes unsigned integers (Uint64);
	// convert to int64 to keep the hash's signed arithmetic. The
	// previous binary.LittleEndian.Int64 call does not exist and
	// failed to compile.
	k1 := int64(binary.LittleEndian.Uint64(data[n*16:]))
	k2 := int64(binary.LittleEndian.Uint64(data[(n*16)+8:]))
	return k1, k2
}
|
@ -0,0 +1,15 @@
|
|||||||
|
// +build !appengine

package murmur

import (
	"unsafe"
)

// getBlock returns the n-th pair of 64-bit blocks from data by
// reinterpreting the backing array in place (no copy; only the
// starting byte &data[n*16] is bounds-checked).
func getBlock(data []byte, n int) (int64, int64) {
	// NOTE(review): assumes a little-endian host that tolerates
	// unaligned 8-byte loads (e.g. amd64/arm64); big-endian targets
	// would produce different block values — confirm supported GOARCHes.
	block := (*[2]int64)(unsafe.Pointer(&data[n*16]))

	k1 := block[0]
	k2 := block[1]
	return k1, k2
}
|
@ -0,0 +1,140 @@
|
|||||||
|
package streams
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"strconv"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
// bucketBits is the number of stream slots tracked per uint64 bucket.
const bucketBits = 64

// IDGenerator tracks and allocates streams which are in use.
type IDGenerator struct {
	// NumStreams is the total number of stream IDs, including the
	// reserved stream 0.
	NumStreams int
	// inuseStreams counts allocated streams; updated atomically.
	inuseStreams int32
	// numBuckets is len(streams), kept as uint32 for atomic math.
	numBuckets uint32

	// streams is a bitset where each bit represents a stream, a 1 implies in use
	streams []uint64
	// offset is the bucket index where the next GetStream scan
	// starts; rotated atomically to spread callers across buckets.
	offset uint32
}
|
||||||
|
|
||||||
|
// New returns an IDGenerator sized for the given protocol version:
// protocol 3+ allows 32768 streams per connection, older versions 128.
// Stream 0 is reserved and pre-marked as in use.
func New(protocol int) *IDGenerator {
	maxStreams := 128
	if protocol > 2 {
		maxStreams = 32768
	}

	buckets := maxStreams / 64
	// reserve stream 0
	streams := make([]uint64, buckets)
	streams[0] = 1 << 63

	return &IDGenerator{
		NumStreams: maxStreams,
		streams:    streams,
		numBuckets: uint32(buckets),
		// Start at the last bucket so GetStream's first increment
		// wraps the scan around to bucket 0.
		offset: uint32(buckets) - 1,
	}
}
|
||||||
|
|
||||||
|
// streamFromBucket converts a (bucket index, bit index within the
// bucket) pair back to an absolute stream ID.
func streamFromBucket(bucket, streamInBucket int) int {
	return (bucket * bucketBits) + streamInBucket
}
|
||||||
|
|
||||||
|
// GetStream allocates a free stream ID, reporting false when every
// stream is in use. It is safe for concurrent use: allocation is
// lock-free, claiming a bit with CAS loops over the bitset buckets.
func (s *IDGenerator) GetStream() (int, bool) {
	// based closely on the java-driver stream ID generator
	// avoid false sharing subsequent requests.
	// Rotate the shared starting offset so concurrent callers begin
	// their scans in different buckets.
	offset := atomic.LoadUint32(&s.offset)
	for !atomic.CompareAndSwapUint32(&s.offset, offset, (offset+1)%s.numBuckets) {
		offset = atomic.LoadUint32(&s.offset)
	}
	offset = (offset + 1) % s.numBuckets

	for i := uint32(0); i < s.numBuckets; i++ {
		pos := int((i + offset) % s.numBuckets)

		bucket := atomic.LoadUint64(&s.streams[pos])
		if bucket == math.MaxUint64 {
			// all streams in use
			continue
		}

		for j := 0; j < bucketBits; j++ {
			mask := uint64(1 << streamOffset(j))
			// While bit j still looks free, try to claim it; on CAS
			// failure reload and re-check (another caller may have
			// taken it, in which case move on to the next bit).
			for bucket&mask == 0 {
				if atomic.CompareAndSwapUint64(&s.streams[pos], bucket, bucket|mask) {
					atomic.AddInt32(&s.inuseStreams, 1)
					return streamFromBucket(int(pos), j), true
				}
				bucket = atomic.LoadUint64(&s.streams[pos])
			}
		}
	}

	return 0, false
}
|
||||||
|
|
||||||
|
// bitfmt renders a bucket's bitset as lowercase hexadecimal, used
// only by String for debugging output.
func bitfmt(b uint64) string {
	const hexBase = 16
	return strconv.FormatUint(b, hexBase)
}
|
||||||
|
|
||||||
|
// returns the bucket offset of a given stream
func bucketOffset(i int) int {
	return i / bucketBits
}

// streamOffset returns the bit position of stream within its bucket.
// Bits are assigned high-to-low: stream 0 maps to bit 63.
func streamOffset(stream int) uint64 {
	return bucketBits - uint64(stream%bucketBits) - 1
}

// isSet reports whether stream's bit is set in the bucket value bits.
func isSet(bits uint64, stream int) bool {
	return bits>>streamOffset(stream)&1 == 1
}
|
||||||
|
|
||||||
|
// isSet reports whether the given stream is currently marked in use.
func (s *IDGenerator) isSet(stream int) bool {
	bits := atomic.LoadUint64(&s.streams[bucketOffset(stream)])
	return isSet(bits, stream)
}
|
||||||
|
|
||||||
|
// String renders each bucket's bitset as hex, for debugging.
func (s *IDGenerator) String() string {
	// size budgets bucketBits+1 bytes per bucket, but bitfmt emits
	// at most 16 hex chars + 1 space, so the final reslice extends
	// len into unused (zeroed) capacity. NOTE(review): the returned
	// string can therefore carry trailing NUL padding — confirm this
	// is acceptable for the debug output it feeds.
	size := s.numBuckets * (bucketBits + 1)
	buf := make([]byte, 0, size)
	for i := 0; i < int(s.numBuckets); i++ {
		bits := atomic.LoadUint64(&s.streams[i])
		buf = append(buf, bitfmt(bits)...)
		buf = append(buf, ' ')
	}
	return string(buf[: size-1 : size-1])
}
|
||||||
|
|
||||||
|
// Clear releases the given stream ID back to the pool. It reports
// whether the stream was actually in use; false means it was already
// clear (possibly released by a concurrent caller).
func (s *IDGenerator) Clear(stream int) (inuse bool) {
	offset := bucketOffset(stream)
	bucket := atomic.LoadUint64(&s.streams[offset])

	mask := uint64(1) << streamOffset(stream)
	if bucket&mask != mask {
		// already cleared
		return false
	}

	// CAS the bit away; on contention, re-check in case another
	// caller cleared it between our load and the swap.
	for !atomic.CompareAndSwapUint64(&s.streams[offset], bucket, bucket & ^mask) {
		bucket = atomic.LoadUint64(&s.streams[offset])
		if bucket&mask != mask {
			// already cleared
			return false
		}
	}

	// TODO: make this account for 0 stream being reserved
	if atomic.AddInt32(&s.inuseStreams, -1) < 0 {
		// TODO(zariel): remove this
		panic("negative streams inuse")
	}

	return true
}
|
||||||
|
|
||||||
|
// Available returns how many stream IDs can still be allocated; the
// trailing -1 accounts for the permanently reserved stream 0.
func (s *IDGenerator) Available() int {
	return s.NumStreams - int(atomic.LoadInt32(&s.inuseStreams)) - 1
}
|
@ -0,0 +1,30 @@
|
|||||||
|
package gocql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// StdLogger is the minimal logging interface this package writes to;
// the standard library's *log.Logger satisfies it.
type StdLogger interface {
	Print(v ...interface{})
	Printf(format string, v ...interface{})
	Println(v ...interface{})
}

// testLogger is a StdLogger that captures output in an in-memory
// buffer so tests can assert on what was logged.
// NOTE(review): capture is not synchronized — confirm tests never log
// from multiple goroutines.
type testLogger struct {
	capture bytes.Buffer
}

func (l *testLogger) Print(v ...interface{})                 { fmt.Fprint(&l.capture, v...) }
func (l *testLogger) Printf(format string, v ...interface{}) { fmt.Fprintf(&l.capture, format, v...) }
func (l *testLogger) Println(v ...interface{})               { fmt.Fprintln(&l.capture, v...) }

// String returns everything captured so far.
func (l *testLogger) String() string { return l.capture.String() }

// defaultLogger forwards to the standard library's default logger.
type defaultLogger struct{}

func (l *defaultLogger) Print(v ...interface{})                 { log.Print(v...) }
func (l *defaultLogger) Printf(format string, v ...interface{}) { log.Printf(format, v...) }
func (l *defaultLogger) Println(v ...interface{})               { log.Println(v...) }

// Logger is the package-wide logger; replace it to redirect this
// package's log output.
var Logger StdLogger = &defaultLogger{}
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,971 @@
|
|||||||
|
// Copyright (c) 2012 The gocql Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
//This file will be the future home for more policies
|
||||||
|
package gocql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
crand "crypto/rand"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/hailocab/go-hostpool"
|
||||||
|
)
|
||||||
|
|
||||||
|
// cowHostList implements a copy on write host list, its equivalent type is []*HostInfo
type cowHostList struct {
	// list holds a *[]*HostInfo; readers Load it without locking.
	list atomic.Value
	// mu serializes writers (set/add/update/remove).
	mu sync.Mutex
}

func (c *cowHostList) String() string {
	return fmt.Sprintf("%+v", c.get())
}

// get returns the current snapshot. Callers must treat it as
// read-only: the backing slice is shared with concurrent readers.
func (c *cowHostList) get() []*HostInfo {
	// TODO(zariel): should we replace this with []*HostInfo?
	l, ok := c.list.Load().(*[]*HostInfo)
	if !ok {
		return nil
	}
	return *l
}

// set replaces the whole list with the given slice.
func (c *cowHostList) set(list []*HostInfo) {
	c.mu.Lock()
	c.list.Store(&list)
	c.mu.Unlock()
}

// add will add a host if it not already in the list
func (c *cowHostList) add(host *HostInfo) bool {
	c.mu.Lock()
	l := c.get()

	if n := len(l); n == 0 {
		l = []*HostInfo{host}
	} else {
		// Copy into a fresh slice so readers of the old snapshot
		// are never mutated; bail out on a duplicate.
		newL := make([]*HostInfo, n+1)
		for i := 0; i < n; i++ {
			if host.Equal(l[i]) {
				c.mu.Unlock()
				return false
			}
			newL[i] = l[i]
		}
		newL[n] = host
		l = newL
	}

	c.list.Store(&l)
	c.mu.Unlock()
	return true
}

// update replaces the stored entry equal to host with host itself;
// it is a no-op when the host is not present.
func (c *cowHostList) update(host *HostInfo) {
	c.mu.Lock()
	l := c.get()

	if len(l) == 0 {
		c.mu.Unlock()
		return
	}

	found := false
	newL := make([]*HostInfo, len(l))
	for i := range l {
		if host.Equal(l[i]) {
			newL[i] = host
			found = true
		} else {
			newL[i] = l[i]
		}
	}

	if found {
		c.list.Store(&newL)
	}

	c.mu.Unlock()
}

// remove drops the host whose connect address equals ip, reporting
// whether anything was removed.
func (c *cowHostList) remove(ip net.IP) bool {
	c.mu.Lock()
	l := c.get()
	size := len(l)
	if size == 0 {
		c.mu.Unlock()
		return false
	}

	found := false
	newL := make([]*HostInfo, 0, size)
	for i := 0; i < len(l); i++ {
		if !l[i].ConnectAddress().Equal(ip) {
			newL = append(newL, l[i])
		} else {
			found = true
		}
	}

	if !found {
		c.mu.Unlock()
		return false
	}

	// Trim len and cap to the surviving element count before
	// publishing the new snapshot.
	newL = newL[: size-1 : size-1]
	c.list.Store(&newL)
	c.mu.Unlock()

	return true
}
|
||||||
|
|
||||||
|
// RetryableQuery is an interface that represents a query or batch statement that
// exposes the correct functions for the retry policy logic to evaluate correctly.
type RetryableQuery interface {
	Attempts() int
	SetConsistency(c Consistency)
	GetConsistency() Consistency
	Context() context.Context
}

// RetryType tells the query executor how to proceed after an error.
type RetryType uint16

const (
	Retry         RetryType = 0x00 // retry on same connection
	RetryNextHost RetryType = 0x01 // retry on another connection
	Ignore        RetryType = 0x02 // ignore error and return result
	Rethrow       RetryType = 0x03 // raise error and stop retrying
)

// ErrUnknownRetryType is returned if the retry policy returns a retry type
// unknown to the query executor.
var ErrUnknownRetryType = errors.New("unknown retry type returned by retry policy")

// RetryPolicy interface is used by gocql to determine if a query can be attempted
// again after a retryable error has been received. The interface allows gocql
// users to implement their own logic to determine if a query can be attempted
// again.
//
// See SimpleRetryPolicy as an example of implementing and using a RetryPolicy
// interface.
type RetryPolicy interface {
	Attempt(RetryableQuery) bool
	GetRetryType(error) RetryType
}
|
||||||
|
|
||||||
|
// SimpleRetryPolicy has simple logic for attempting a query a fixed number of times.
//
// See below for examples of usage:
//
//	//Assign to the cluster
//	cluster.RetryPolicy = &gocql.SimpleRetryPolicy{NumRetries: 3}
//
//	//Assign to a query
//	query.RetryPolicy(&gocql.SimpleRetryPolicy{NumRetries: 1})
type SimpleRetryPolicy struct {
	NumRetries int //Number of times to retry a query
}

// Attempt tells gocql to attempt the query again based on query.Attempts being less
// than the NumRetries defined in the policy.
func (s *SimpleRetryPolicy) Attempt(q RetryableQuery) bool {
	return q.Attempts() <= s.NumRetries
}

// GetRetryType always retries on another host, regardless of error.
func (s *SimpleRetryPolicy) GetRetryType(err error) RetryType {
	return RetryNextHost
}
|
||||||
|
|
||||||
|
// ExponentialBackoffRetryPolicy sleeps between attempts
type ExponentialBackoffRetryPolicy struct {
	NumRetries int
	// Min and Max bound the per-attempt sleep; non-positive values
	// fall back to defaults in getExponentialTime (100ms / 10s).
	Min, Max time.Duration
}

// Attempt sleeps for an exponentially growing interval and then
// permits the retry while the attempt count is within NumRetries.
func (e *ExponentialBackoffRetryPolicy) Attempt(q RetryableQuery) bool {
	if q.Attempts() > e.NumRetries {
		return false
	}
	// NOTE(review): this blocks the calling goroutine for the whole
	// backoff interval.
	time.Sleep(e.napTime(q.Attempts()))
	return true
}
|
||||||
|
|
||||||
|
// getExponentialTime computes the sleep before retry attempt number
// attempts: min * 2^(attempts-1) plus random jitter of up to ±min/2,
// clamped to max. Non-positive min/max fall back to 100ms and 10s.
func getExponentialTime(min time.Duration, max time.Duration, attempts int) time.Duration {
	const (
		defaultMin = 100 * time.Millisecond
		defaultMax = 10 * time.Second
	)
	if min <= 0 {
		min = defaultMin
	}
	if max <= 0 {
		max = defaultMax
	}

	base := float64(min)
	nap := base * math.Pow(2, float64(attempts-1))
	// add some jitter
	nap += rand.Float64()*base - (base / 2)
	if nap > float64(max) {
		return max
	}
	return time.Duration(nap)
}
|
||||||
|
|
||||||
|
// GetRetryType always directs the retry to another host.
func (e *ExponentialBackoffRetryPolicy) GetRetryType(err error) RetryType {
	return RetryNextHost
}
|
||||||
|
|
||||||
|
// DowngradingConsistencyRetryPolicy: Next retry will be with the next consistency level
// provided in the slice
//
// On a read timeout: the operation is retried with the next provided consistency
// level.
//
// On a write timeout: if the operation is an :attr:`~.UNLOGGED_BATCH`
// and at least one replica acknowledged the write, the operation is
// retried with the next consistency level. Furthermore, for other
// write types, if at least one replica acknowledged the write, the
// timeout is ignored.
//
// On an unavailable exception: if at least one replica is alive, the
// operation is retried with the next provided consistency level.
type DowngradingConsistencyRetryPolicy struct {
	ConsistencyLevelsToTry []Consistency
}

// Attempt allows one retry per configured consistency level,
// downgrading the query to ConsistencyLevelsToTry[attempt-1] before
// each retry after the first attempt.
func (d *DowngradingConsistencyRetryPolicy) Attempt(q RetryableQuery) bool {
	currentAttempt := q.Attempts()

	if currentAttempt > len(d.ConsistencyLevelsToTry) {
		return false
	} else if currentAttempt > 0 {
		q.SetConsistency(d.ConsistencyLevelsToTry[currentAttempt-1])
		if gocqlDebug {
			Logger.Printf("%T: set consistency to %q\n",
				d,
				d.ConsistencyLevelsToTry[currentAttempt-1])
		}
	}
	return true
}

// GetRetryType maps the server error to a retry decision following
// the rules documented on the type above.
func (d *DowngradingConsistencyRetryPolicy) GetRetryType(err error) RetryType {
	switch t := err.(type) {
	case *RequestErrUnavailable:
		if t.Alive > 0 {
			return Retry
		}
		return Rethrow
	case *RequestErrWriteTimeout:
		if t.WriteType == "SIMPLE" || t.WriteType == "BATCH" || t.WriteType == "COUNTER" {
			if t.Received > 0 {
				return Ignore
			}
			return Rethrow
		}
		if t.WriteType == "UNLOGGED_BATCH" {
			return Retry
		}
		return Rethrow
	case *RequestErrReadTimeout:
		return Retry
	default:
		return RetryNextHost
	}
}
|
||||||
|
|
||||||
|
// napTime returns the backoff duration for the given attempt number,
// using the policy's Min/Max bounds.
func (e *ExponentialBackoffRetryPolicy) napTime(attempts int) time.Duration {
	return getExponentialTime(e.Min, e.Max, attempts)
}
|
||||||
|
|
||||||
|
// HostStateNotifier receives cluster topology and host liveness
// change notifications.
type HostStateNotifier interface {
	AddHost(host *HostInfo)
	RemoveHost(host *HostInfo)
	HostUp(host *HostInfo)
	HostDown(host *HostInfo)
}

// KeyspaceUpdateEvent describes a schema change to a keyspace.
type KeyspaceUpdateEvent struct {
	Keyspace string
	Change   string
}

// HostSelectionPolicy is an interface for selecting
// the most appropriate host to execute a given query.
type HostSelectionPolicy interface {
	HostStateNotifier
	SetPartitioner
	KeyspaceChanged(KeyspaceUpdateEvent)
	Init(*Session)
	IsLocal(host *HostInfo) bool
	//Pick returns an iteration function over selected hosts
	Pick(ExecutableQuery) NextHost
}

// SelectedHost is an interface returned when picking a host from a host
// selection policy.
type SelectedHost interface {
	Info() *HostInfo
	Mark(error)
}

// selectedHost adapts a bare *HostInfo to the SelectedHost interface.
type selectedHost HostInfo

func (host *selectedHost) Info() *HostInfo {
	return (*HostInfo)(host)
}

// Mark is a no-op: this adapter does not track query outcomes.
func (host *selectedHost) Mark(err error) {}

// NextHost is an iteration function over picked hosts
type NextHost func() SelectedHost
|
||||||
|
|
||||||
|
// RoundRobinHostPolicy is a round-robin load balancing policy, where each host
// is tried sequentially for each query.
func RoundRobinHostPolicy() HostSelectionPolicy {
	return &roundRobinHostPolicy{}
}

type roundRobinHostPolicy struct {
	hosts cowHostList
}

// The policy is token/keyspace/partitioner agnostic, so these
// notifications are no-ops and every host counts as local.
func (r *roundRobinHostPolicy) IsLocal(*HostInfo) bool              { return true }
func (r *roundRobinHostPolicy) KeyspaceChanged(KeyspaceUpdateEvent) {}
func (r *roundRobinHostPolicy) SetPartitioner(partitioner string)   {}
func (r *roundRobinHostPolicy) Init(*Session)                       {}

// Pick snapshots the host list, shuffles the copy so iteration starts
// from a random position, and returns a round-robin iterator over it.
func (r *roundRobinHostPolicy) Pick(qry ExecutableQuery) NextHost {
	src := r.hosts.get()
	hosts := make([]*HostInfo, len(src))
	copy(hosts, src)

	rand := rand.New(randSource())
	rand.Shuffle(len(hosts), func(i, j int) {
		hosts[i], hosts[j] = hosts[j], hosts[i]
	})

	return roundRobbin(hosts)
}

func (r *roundRobinHostPolicy) AddHost(host *HostInfo) {
	r.hosts.add(host)
}

func (r *roundRobinHostPolicy) RemoveHost(host *HostInfo) {
	r.hosts.remove(host.ConnectAddress())
}

// HostUp re-adds a host that came back up.
func (r *roundRobinHostPolicy) HostUp(host *HostInfo) {
	r.AddHost(host)
}

// HostDown drops a host that went down.
func (r *roundRobinHostPolicy) HostDown(host *HostInfo) {
	r.RemoveHost(host)
}
|
||||||
|
|
||||||
|
// ShuffleReplicas is a TokenAwareHostPolicy option that randomizes
// the order in which a token's replicas are tried.
func ShuffleReplicas() func(*tokenAwareHostPolicy) {
	return func(t *tokenAwareHostPolicy) {
		t.shuffleReplicas = true
	}
}

// NonLocalReplicasFallback enables fallback to replicas that are not considered local.
//
// TokenAwareHostPolicy used with DCAwareHostPolicy fallback first selects replicas by partition key in local DC, then
// falls back to other nodes in the local DC. Enabling NonLocalReplicasFallback causes TokenAwareHostPolicy
// to first select replicas by partition key in local DC, then replicas by partition key in remote DCs and fall back
// to other nodes in local DC.
func NonLocalReplicasFallback() func(policy *tokenAwareHostPolicy) {
	return func(t *tokenAwareHostPolicy) {
		t.nonLocalReplicasFallback = true
	}
}

// TokenAwareHostPolicy is a token aware host selection policy, where hosts are
// selected based on the partition key, so queries are sent to the host which
// owns the partition. Fallback is used when routing information is not available.
func TokenAwareHostPolicy(fallback HostSelectionPolicy, opts ...func(*tokenAwareHostPolicy)) HostSelectionPolicy {
	p := &tokenAwareHostPolicy{fallback: fallback}
	for _, opt := range opts {
		opt(p)
	}
	return p
}
|
||||||
|
|
||||||
|
// clusterMeta holds metadata about cluster topology.
// It is used inside atomic.Value and shallow copies are used when replacing it,
// so fields should not be modified in-place. Instead, to modify a field a copy of the field should be made
// and the pointer in clusterMeta updated to point to the new value.
type clusterMeta struct {
	// replicas is map[keyspace]map[token]hosts
	replicas  map[string]tokenRingReplicas
	tokenRing *tokenRing
}

type tokenAwareHostPolicy struct {
	// fallback handles locality and host selection whenever routing
	// information is missing.
	fallback            HostSelectionPolicy
	getKeyspaceMetadata func(keyspace string) (*KeyspaceMetadata, error)
	getKeyspaceName     func() string

	shuffleReplicas          bool
	nonLocalReplicasFallback bool

	// mu protects writes to hosts, partitioner, metadata.
	// reads can be unlocked as long as they are not used for updating state later.
	mu          sync.Mutex
	hosts       cowHostList
	partitioner string
	metadata    atomic.Value // *clusterMeta
}
|
||||||
|
|
||||||
|
// Init wires the policy to the session it will serve; called by the
// session before the policy is used.
func (t *tokenAwareHostPolicy) Init(s *Session) {
	t.getKeyspaceMetadata = s.KeyspaceMetadata
	t.getKeyspaceName = func() string { return s.cfg.Keyspace }
}

// IsLocal delegates locality decisions to the fallback policy.
func (t *tokenAwareHostPolicy) IsLocal(host *HostInfo) bool {
	return t.fallback.IsLocal(host)
}

// KeyspaceChanged recomputes the replica map for the changed keyspace
// and publishes a fresh metadata snapshot.
func (t *tokenAwareHostPolicy) KeyspaceChanged(update KeyspaceUpdateEvent) {
	t.mu.Lock()
	defer t.mu.Unlock()
	meta := t.getMetadataForUpdate()
	t.updateReplicas(meta, update.Keyspace)
	t.metadata.Store(meta)
}
|
||||||
|
|
||||||
|
// updateReplicas updates replicas in clusterMeta.
// It must be called with t.mu mutex locked.
// meta must not be nil and it's replicas field will be updated.
func (t *tokenAwareHostPolicy) updateReplicas(meta *clusterMeta, keyspace string) {
	newReplicas := make(map[string]tokenRingReplicas, len(meta.replicas))

	// Recompute the changed keyspace's replica map from its
	// replication strategy, if both metadata and a ring exist.
	ks, err := t.getKeyspaceMetadata(keyspace)
	if err == nil {
		strat := getStrategy(ks)
		if strat != nil {
			if meta != nil && meta.tokenRing != nil {
				newReplicas[keyspace] = strat.replicaMap(meta.tokenRing)
			}
		}
	}

	// Carry over the unchanged keyspaces as-is.
	for ks, replicas := range meta.replicas {
		if ks != keyspace {
			newReplicas[ks] = replicas
		}
	}

	meta.replicas = newReplicas
}
|
||||||
|
|
||||||
|
// SetPartitioner records the cluster's partitioner and, when it
// actually changes, rebuilds the token ring and replica map and
// publishes a new metadata snapshot.
func (t *tokenAwareHostPolicy) SetPartitioner(partitioner string) {
	t.mu.Lock()
	defer t.mu.Unlock()

	if t.partitioner != partitioner {
		t.fallback.SetPartitioner(partitioner)
		t.partitioner = partitioner
		meta := t.getMetadataForUpdate()
		meta.resetTokenRing(t.partitioner, t.hosts.get())
		t.updateReplicas(meta, t.getKeyspaceName())
		t.metadata.Store(meta)
	}
}
|
||||||
|
|
||||||
|
// AddHost adds the host to the policy's host list; if it was not
// already present, the token ring and replica map are rebuilt and a
// new metadata snapshot is published. The fallback policy is always
// notified (outside the lock).
func (t *tokenAwareHostPolicy) AddHost(host *HostInfo) {
	t.mu.Lock()
	if t.hosts.add(host) {
		meta := t.getMetadataForUpdate()
		meta.resetTokenRing(t.partitioner, t.hosts.get())
		t.updateReplicas(meta, t.getKeyspaceName())
		t.metadata.Store(meta)
	}
	t.mu.Unlock()

	t.fallback.AddHost(host)
}

// AddHosts adds several hosts at once, rebuilding the token ring a
// single time rather than once per host.
func (t *tokenAwareHostPolicy) AddHosts(hosts []*HostInfo) {
	t.mu.Lock()

	for _, host := range hosts {
		t.hosts.add(host)
	}

	meta := t.getMetadataForUpdate()
	meta.resetTokenRing(t.partitioner, t.hosts.get())
	t.updateReplicas(meta, t.getKeyspaceName())
	t.metadata.Store(meta)

	t.mu.Unlock()

	for _, host := range hosts {
		t.fallback.AddHost(host)
	}
}

// RemoveHost is the inverse of AddHost: drop the host, rebuild the
// ring if it was present, then notify the fallback policy.
func (t *tokenAwareHostPolicy) RemoveHost(host *HostInfo) {
	t.mu.Lock()
	if t.hosts.remove(host.ConnectAddress()) {
		meta := t.getMetadataForUpdate()
		meta.resetTokenRing(t.partitioner, t.hosts.get())
		t.updateReplicas(meta, t.getKeyspaceName())
		t.metadata.Store(meta)
	}
	t.mu.Unlock()

	t.fallback.RemoveHost(host)
}
|
||||||
|
|
||||||
|
// HostUp forwards the liveness change to the fallback policy; the
// token ring itself is unaffected by up/down transitions.
func (t *tokenAwareHostPolicy) HostUp(host *HostInfo) {
	t.fallback.HostUp(host)
}

// HostDown forwards the liveness change to the fallback policy.
func (t *tokenAwareHostPolicy) HostDown(host *HostInfo) {
	t.fallback.HostDown(host)
}
|
||||||
|
|
||||||
|
// getMetadataReadOnly returns current cluster metadata.
// Metadata uses copy on write, so the returned value should be only used for reading.
// To obtain a copy that could be updated, use getMetadataForUpdate instead.
func (t *tokenAwareHostPolicy) getMetadataReadOnly() *clusterMeta {
	meta, _ := t.metadata.Load().(*clusterMeta)
	return meta
}

// getMetadataForUpdate returns clusterMeta suitable for updating.
// It is a SHALLOW copy of current metadata in case it was already set or new empty clusterMeta otherwise.
// This function should be called with t.mu mutex locked and the mutex should not be released before
// storing the new metadata.
func (t *tokenAwareHostPolicy) getMetadataForUpdate() *clusterMeta {
	metaReadOnly := t.getMetadataReadOnly()
	meta := new(clusterMeta)
	if metaReadOnly != nil {
		*meta = *metaReadOnly
	}
	return meta
}
|
||||||
|
|
||||||
|
// resetTokenRing creates a new tokenRing.
// It must be called with t.mu locked.
func (m *clusterMeta) resetTokenRing(partitioner string, hosts []*HostInfo) {
	if partitioner == "" {
		// partitioner not yet set
		return
	}

	// create a new token ring
	tokenRing, err := newTokenRing(partitioner, hosts)
	if err != nil {
		// Keep the previous ring rather than publishing a broken one.
		Logger.Printf("Unable to update the token ring due to error: %s", err)
		return
	}

	// replace the token ring
	m.tokenRing = tokenRing
}
|
||||||
|
|
||||||
|
func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost {
|
||||||
|
if qry == nil {
|
||||||
|
return t.fallback.Pick(qry)
|
||||||
|
}
|
||||||
|
|
||||||
|
routingKey, err := qry.GetRoutingKey()
|
||||||
|
if err != nil {
|
||||||
|
return t.fallback.Pick(qry)
|
||||||
|
} else if routingKey == nil {
|
||||||
|
return t.fallback.Pick(qry)
|
||||||
|
}
|
||||||
|
|
||||||
|
meta := t.getMetadataReadOnly()
|
||||||
|
if meta == nil || meta.tokenRing == nil {
|
||||||
|
return t.fallback.Pick(qry)
|
||||||
|
}
|
||||||
|
|
||||||
|
token := meta.tokenRing.partitioner.Hash(routingKey)
|
||||||
|
ht := meta.replicas[qry.Keyspace()].replicasFor(token)
|
||||||
|
|
||||||
|
var replicas []*HostInfo
|
||||||
|
if ht == nil {
|
||||||
|
host, _ := meta.tokenRing.GetHostForToken(token)
|
||||||
|
replicas = []*HostInfo{host}
|
||||||
|
} else if t.shuffleReplicas {
|
||||||
|
replicas = shuffleHosts(replicas)
|
||||||
|
} else {
|
||||||
|
replicas = ht.hosts
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
fallbackIter NextHost
|
||||||
|
i, j int
|
||||||
|
remote []*HostInfo
|
||||||
|
)
|
||||||
|
|
||||||
|
used := make(map[*HostInfo]bool, len(replicas))
|
||||||
|
return func() SelectedHost {
|
||||||
|
for i < len(replicas) {
|
||||||
|
h := replicas[i]
|
||||||
|
i++
|
||||||
|
|
||||||
|
if !t.fallback.IsLocal(h) {
|
||||||
|
remote = append(remote, h)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.IsUp() {
|
||||||
|
used[h] = true
|
||||||
|
return (*selectedHost)(h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.nonLocalReplicasFallback {
|
||||||
|
for j < len(remote) {
|
||||||
|
h := remote[j]
|
||||||
|
j++
|
||||||
|
|
||||||
|
if h.IsUp() {
|
||||||
|
used[h] = true
|
||||||
|
return (*selectedHost)(h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if fallbackIter == nil {
|
||||||
|
// fallback
|
||||||
|
fallbackIter = t.fallback.Pick(qry)
|
||||||
|
}
|
||||||
|
|
||||||
|
// filter the token aware selected hosts from the fallback hosts
|
||||||
|
for fallbackHost := fallbackIter(); fallbackHost != nil; fallbackHost = fallbackIter() {
|
||||||
|
if !used[fallbackHost.Info()] {
|
||||||
|
used[fallbackHost.Info()] = true
|
||||||
|
return fallbackHost
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HostPoolHostPolicy is a host policy which uses the bitly/go-hostpool library
// to distribute queries between hosts and prevent sending queries to
// unresponsive hosts. When creating the host pool that is passed to the policy
// use an empty slice of hosts as the hostpool will be populated later by gocql.
// See below for examples of usage:
//
//     // Create host selection policy using a simple host pool
//     cluster.PoolConfig.HostSelectionPolicy = HostPoolHostPolicy(hostpool.New(nil))
//
//     // Create host selection policy using an epsilon greedy pool
//     cluster.PoolConfig.HostSelectionPolicy = HostPoolHostPolicy(
//         hostpool.NewEpsilonGreedy(nil, 0, &hostpool.LinearEpsilonValueCalculator{}),
//     )
//
func HostPoolHostPolicy(hp hostpool.HostPool) HostSelectionPolicy {
	return &hostPoolHostPolicy{hostMap: map[string]*HostInfo{}, hp: hp}
}

// hostPoolHostPolicy adapts a hostpool.HostPool to the HostSelectionPolicy
// interface. mu guards both hp and hostMap.
type hostPoolHostPolicy struct {
	hp hostpool.HostPool
	mu sync.RWMutex
	// hostMap maps connect-address strings to the corresponding HostInfo.
	hostMap map[string]*HostInfo
}
|
||||||
|
|
||||||
|
// The host pool policy needs no session, keyspace, or partitioner state,
// and treats every host as local.
func (r *hostPoolHostPolicy) Init(*Session)                       {}
func (r *hostPoolHostPolicy) KeyspaceChanged(KeyspaceUpdateEvent) {}
func (r *hostPoolHostPolicy) SetPartitioner(string)               {}
func (r *hostPoolHostPolicy) IsLocal(*HostInfo) bool              { return true }
|
||||||
|
|
||||||
|
func (r *hostPoolHostPolicy) SetHosts(hosts []*HostInfo) {
|
||||||
|
peers := make([]string, len(hosts))
|
||||||
|
hostMap := make(map[string]*HostInfo, len(hosts))
|
||||||
|
|
||||||
|
for i, host := range hosts {
|
||||||
|
ip := host.ConnectAddress().String()
|
||||||
|
peers[i] = ip
|
||||||
|
hostMap[ip] = host
|
||||||
|
}
|
||||||
|
|
||||||
|
r.mu.Lock()
|
||||||
|
r.hp.SetHosts(peers)
|
||||||
|
r.hostMap = hostMap
|
||||||
|
r.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddHost registers a single host with the policy, rebuilding the host pool's
// peer list to include it. Adding an already-known host is a no-op.
func (r *hostPoolHostPolicy) AddHost(host *HostInfo) {
	ip := host.ConnectAddress().String()

	r.mu.Lock()
	defer r.mu.Unlock()

	// If the host addr is present and isn't nil return
	if h, ok := r.hostMap[ip]; ok && h != nil {
		return
	}
	// otherwise, add the host to the map
	r.hostMap[ip] = host
	// and construct a new peer list to give to the HostPool
	hosts := make([]string, 0, len(r.hostMap))
	for addr := range r.hostMap {
		hosts = append(hosts, addr)
	}

	r.hp.SetHosts(hosts)
}
|
||||||
|
|
||||||
|
// RemoveHost drops a host from the policy and reseeds the host pool with the
// remaining peers. Removing an unknown host is a no-op.
func (r *hostPoolHostPolicy) RemoveHost(host *HostInfo) {
	ip := host.ConnectAddress().String()

	r.mu.Lock()
	defer r.mu.Unlock()

	if _, ok := r.hostMap[ip]; !ok {
		return
	}

	delete(r.hostMap, ip)
	// Rebuild the peer list from what is left in the map.
	hosts := make([]string, 0, len(r.hostMap))
	for _, host := range r.hostMap {
		hosts = append(hosts, host.ConnectAddress().String())
	}

	r.hp.SetHosts(hosts)
}
|
||||||
|
|
||||||
|
// HostUp treats an up notification as (re-)adding the host to the pool.
func (r *hostPoolHostPolicy) HostUp(host *HostInfo) {
	r.AddHost(host)
}

// HostDown removes a downed host from the pool entirely.
func (r *hostPoolHostPolicy) HostDown(host *HostInfo) {
	r.RemoveHost(host)
}
|
||||||
|
|
||||||
|
// Pick returns an iterator that asks the host pool for one host per call.
// The query argument is unused; host selection is driven entirely by the pool.
func (r *hostPoolHostPolicy) Pick(qry ExecutableQuery) NextHost {
	return func() SelectedHost {
		r.mu.RLock()
		defer r.mu.RUnlock()

		if len(r.hostMap) == 0 {
			return nil
		}

		hostR := r.hp.Get()
		host, ok := r.hostMap[hostR.Host()]
		if !ok {
			// Pool returned an address we no longer track (removed
			// concurrently); report no host available.
			return nil
		}

		// Wrap the result so Mark() can feed latency/error data back
		// into the pool.
		return selectedHostPoolHost{
			policy: r,
			info:   host,
			hostR:  hostR,
		}
	}
}
|
||||||
|
|
||||||
|
// selectedHostPoolHost is a host returned by the hostPoolHostPolicy and
// implements the SelectedHost interface
type selectedHostPoolHost struct {
	policy *hostPoolHostPolicy
	info   *HostInfo
	// hostR is the pool's handle for this pick; Mark() reports back on it.
	hostR hostpool.HostPoolResponse
}

// Info returns the HostInfo this selection refers to.
func (host selectedHostPoolHost) Info() *HostInfo {
	return host.info
}
|
||||||
|
|
||||||
|
// Mark reports the outcome of a query against this host back to the host
// pool, unless the host was removed from the policy in the meantime.
func (host selectedHostPoolHost) Mark(err error) {
	ip := host.info.ConnectAddress().String()

	host.policy.mu.RLock()
	defer host.policy.mu.RUnlock()

	if _, ok := host.policy.hostMap[ip]; !ok {
		// host was removed between pick and mark
		return
	}

	host.hostR.Mark(err)
}
|
||||||
|
|
||||||
|
// dcAwareRR partitions hosts by datacenter so that local-DC hosts can be
// tried before remote ones.
type dcAwareRR struct {
	// local is the name of the preferred (local) datacenter.
	local       string
	localHosts  cowHostList
	remoteHosts cowHostList
}

// DCAwareRoundRobinPolicy is a host selection policy which will prioritize and
// return hosts which are in the local datacenter before returning hosts in all
// other datacenters
func DCAwareRoundRobinPolicy(localDC string) HostSelectionPolicy {
	return &dcAwareRR{local: localDC}
}

// No session, keyspace, or partitioner state is needed by this policy.
func (d *dcAwareRR) Init(*Session)                       {}
func (d *dcAwareRR) KeyspaceChanged(KeyspaceUpdateEvent) {}
func (d *dcAwareRR) SetPartitioner(p string)             {}

// IsLocal reports whether the host is in the configured local datacenter.
func (d *dcAwareRR) IsLocal(host *HostInfo) bool {
	return host.DataCenter() == d.local
}
|
||||||
|
|
||||||
|
// AddHost files the host into the local or remote list based on its datacenter.
func (d *dcAwareRR) AddHost(host *HostInfo) {
	if d.IsLocal(host) {
		d.localHosts.add(host)
	} else {
		d.remoteHosts.add(host)
	}
}

// RemoveHost removes the host from whichever list its datacenter places it in.
// NOTE(review): this assumes the host's datacenter has not changed since
// AddHost — if it did, removal would target the wrong list; verify upstream.
func (d *dcAwareRR) RemoveHost(host *HostInfo) {
	if d.IsLocal(host) {
		d.localHosts.remove(host.ConnectAddress())
	} else {
		d.remoteHosts.remove(host.ConnectAddress())
	}
}

// Up/down notifications simply add or remove the host.
func (d *dcAwareRR) HostUp(host *HostInfo)   { d.AddHost(host) }
func (d *dcAwareRR) HostDown(host *HostInfo) { d.RemoveHost(host) }
|
||||||
|
|
||||||
|
// randSeed is a process-wide counter used to derive distinct math/rand seeds.
// It is initialized from crypto/rand and advanced atomically per source.
var randSeed int64

func init() {
	// Seed from the OS CSPRNG so different processes don't shuffle hosts
	// identically. Panics only if the system entropy source is broken.
	p := make([]byte, 8)
	if _, err := crand.Read(p); err != nil {
		panic(err)
	}
	randSeed = int64(binary.BigEndian.Uint64(p))
}

// randSource returns a new math/rand source with a unique seed.
// Note: rand.Source values are not safe for concurrent use; each caller
// gets its own.
func randSource() rand.Source {
	return rand.NewSource(atomic.AddInt64(&randSeed, 1))
}
|
||||||
|
|
||||||
|
// roundRobbin (sic — name kept for package-internal callers) returns an
// iterator that walks the given hosts once, in order, yielding only hosts
// that are up, and nil when exhausted.
func roundRobbin(hosts []*HostInfo) NextHost {
	// i persists across calls to the returned closure.
	var i int
	return func() SelectedHost {
		for i < len(hosts) {
			h := hosts[i]
			i++

			if h.IsUp() {
				return (*selectedHost)(h)
			}
		}

		return nil
	}
}
|
||||||
|
|
||||||
|
// Pick returns an iterator over all hosts with local-DC hosts first, each DC
// group shuffled independently so load spreads across hosts.
func (d *dcAwareRR) Pick(q ExecutableQuery) NextHost {
	local := d.localHosts.get()
	remote := d.remoteHosts.get()

	// Concatenate local + remote into one slice so the iterator naturally
	// prefers local hosts.
	hosts := make([]*HostInfo, len(local)+len(remote))
	n := copy(hosts, local)
	copy(hosts[n:], remote)

	// TODO: use random chose-2 but that will require plumbing information
	// about connection/host load to here
	r := rand.New(randSource())
	// Shuffle each DC segment in place, keeping local before remote.
	for _, l := range [][]*HostInfo{hosts[:len(local)], hosts[len(local):]} {
		r.Shuffle(len(l), func(i, j int) {
			l[i], l[j] = l[j], l[i]
		})
	}

	return roundRobbin(hosts)
}
|
||||||
|
|
||||||
|
// ConvictionPolicy interface is used by gocql to determine if a host should be
// marked as DOWN based on the error and host info
type ConvictionPolicy interface {
	// Implementations should return `true` if the host should be convicted, `false` otherwise.
	AddFailure(error error, host *HostInfo) bool
	// Implementations should clear out any convictions or state regarding the host.
	Reset(host *HostInfo)
}

// SimpleConvictionPolicy implements a ConvictionPolicy which convicts all hosts
// regardless of error
type SimpleConvictionPolicy struct {
}

// AddFailure always convicts: any error marks the host down.
func (e *SimpleConvictionPolicy) AddFailure(error error, host *HostInfo) bool {
	return true
}

// Reset is a no-op; this policy keeps no per-host state.
func (e *SimpleConvictionPolicy) Reset(host *HostInfo) {}
|
||||||
|
|
||||||
|
// ReconnectionPolicy interface is used by gocql to determine if reconnection
// can be attempted after connection error. The interface allows gocql users
// to implement their own logic to determine how to attempt reconnection.
type ReconnectionPolicy interface {
	// GetInterval returns the wait time before retry number currentRetry.
	GetInterval(currentRetry int) time.Duration
	// GetMaxRetries returns the total number of reconnect attempts allowed.
	GetMaxRetries() int
}
|
||||||
|
|
||||||
|
// ConstantReconnectionPolicy has simple logic for returning a fixed reconnection interval.
|
||||||
|
//
|
||||||
|
// Examples of usage:
|
||||||
|
//
|
||||||
|
// cluster.ReconnectionPolicy = &gocql.ConstantReconnectionPolicy{MaxRetries: 10, Interval: 8 * time.Second}
|
||||||
|
//
|
||||||
|
type ConstantReconnectionPolicy struct {
|
||||||
|
MaxRetries int
|
||||||
|
Interval time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ConstantReconnectionPolicy) GetInterval(currentRetry int) time.Duration {
|
||||||
|
return c.Interval
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ConstantReconnectionPolicy) GetMaxRetries() int {
|
||||||
|
return c.MaxRetries
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExponentialReconnectionPolicy returns a growing reconnection interval.
type ExponentialReconnectionPolicy struct {
	MaxRetries      int
	InitialInterval time.Duration
}

// GetInterval grows the wait exponentially from InitialInterval, capped at
// math.MaxInt16 seconds (exact growth curve is defined by getExponentialTime,
// declared elsewhere in this package).
func (e *ExponentialReconnectionPolicy) GetInterval(currentRetry int) time.Duration {
	return getExponentialTime(e.InitialInterval, math.MaxInt16*time.Second, currentRetry)
}

// GetMaxRetries returns the configured attempt limit.
func (e *ExponentialReconnectionPolicy) GetMaxRetries() int {
	return e.MaxRetries
}
|
||||||
|
|
||||||
|
type SpeculativeExecutionPolicy interface {
|
||||||
|
Attempts() int
|
||||||
|
Delay() time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
type NonSpeculativeExecution struct{}
|
||||||
|
|
||||||
|
func (sp NonSpeculativeExecution) Attempts() int { return 0 } // No additional attempts
|
||||||
|
func (sp NonSpeculativeExecution) Delay() time.Duration { return 1 } // The delay. Must be positive to be used in a ticker.
|
||||||
|
|
||||||
|
type SimpleSpeculativeExecution struct {
|
||||||
|
NumAttempts int
|
||||||
|
TimeoutDelay time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sp *SimpleSpeculativeExecution) Attempts() int { return sp.NumAttempts }
|
||||||
|
func (sp *SimpleSpeculativeExecution) Delay() time.Duration { return sp.TimeoutDelay }
|
@ -0,0 +1,64 @@
|
|||||||
|
package gocql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gocql/gocql/internal/lru"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// defaultMaxPreparedStmts is the default capacity of the prepared statement cache.
const defaultMaxPreparedStmts = 1000

// preparedLRU is the prepared statement cache
type preparedLRU struct {
	// mu guards all access to lru.
	mu  sync.Mutex
	lru *lru.Cache
}
|
||||||
|
|
||||||
|
// max adjusts the maximum size of the cache and cleans up the oldest records
// if the new max is lower than the previous value. (The original comment
// claimed "Not concurrency safe", but the method takes p.mu for its whole
// duration, so it is safe to call concurrently with the other methods.)
func (p *preparedLRU) max(max int) {
	p.mu.Lock()
	defer p.mu.Unlock()

	// Evict oldest entries until we fit under the new cap.
	for p.lru.Len() > max {
		p.lru.RemoveOldest()
	}
	p.lru.MaxEntries = max
}
|
||||||
|
|
||||||
|
// clear evicts every entry from the cache, leaving its capacity unchanged.
func (p *preparedLRU) clear() {
	p.mu.Lock()
	defer p.mu.Unlock()

	for p.lru.Len() > 0 {
		p.lru.RemoveOldest()
	}
}
|
||||||
|
|
||||||
|
// add inserts (or overwrites) an in-flight prepare under the given key.
func (p *preparedLRU) add(key string, val *inflightPrepare) {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.lru.Add(key, val)
}
|
||||||
|
|
||||||
|
// remove deletes the entry for key, reporting whether it was present.
func (p *preparedLRU) remove(key string) bool {
	p.mu.Lock()
	defer p.mu.Unlock()
	return p.lru.Remove(key)
}
|
||||||
|
|
||||||
|
// execIfMissing returns the cached in-flight prepare for key if one exists
// (second result true); otherwise it invokes fn with the cache still locked
// so fn can insert a new entry atomically, and returns fn's result (false).
func (p *preparedLRU) execIfMissing(key string, fn func(lru *lru.Cache) *inflightPrepare) (*inflightPrepare, bool) {
	p.mu.Lock()
	defer p.mu.Unlock()

	val, ok := p.lru.Get(key)
	if ok {
		return val.(*inflightPrepare), true
	}

	return fn(p.lru), false
}
|
||||||
|
|
||||||
|
// keyFor builds the cache key for a prepared statement on a given host and
// keyspace.
// NOTE(review): plain concatenation with no separator can collide
// (addr+keyspace boundaries are ambiguous) — harmless only if inputs never
// overlap; consider a delimiter. TODO confirm with callers.
func (p *preparedLRU) keyFor(addr, keyspace, statement string) string {
	// TODO: maybe use []byte for keys?
	return addr + keyspace + statement
}
|
@ -0,0 +1,161 @@
|
|||||||
|
package gocql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExecutableQuery is the contract the query executor needs from both Query
// and Batch: how to run one attempt, record metrics for it, and expose the
// routing/retry configuration.
type ExecutableQuery interface {
	// execute runs a single attempt on the given connection.
	execute(ctx context.Context, conn *Conn) *Iter
	// attempt records timing/metrics for a completed attempt.
	attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo)
	retryPolicy() RetryPolicy
	speculativeExecutionPolicy() SpeculativeExecutionPolicy
	GetRoutingKey() ([]byte, error)
	Keyspace() string
	IsIdempotent() bool

	withContext(context.Context) ExecutableQuery

	RetryableQuery
}
|
||||||
|
|
||||||
|
// queryExecutor runs queries against the connection pool, choosing hosts via
// the configured selection policy.
type queryExecutor struct {
	pool   *policyConnPool
	policy HostSelectionPolicy
}
|
||||||
|
|
||||||
|
// attemptQuery runs one attempt on conn, timing it and recording the attempt
// metrics before returning the iterator.
func (q *queryExecutor) attemptQuery(ctx context.Context, qry ExecutableQuery, conn *Conn) *Iter {
	start := time.Now()
	iter := qry.execute(ctx, conn)
	end := time.Now()

	qry.attempt(q.pool.keyspace, end, start, iter, conn.host)

	return iter
}
|
||||||
|
|
||||||
|
// speculate launches up to sp.Attempts() additional executions of qry, one
// per sp.Delay() tick, and returns the first result to arrive (or the
// context error). It returns nil if the attempt budget is exhausted with no
// result yet, letting the caller keep waiting.
func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp SpeculativeExecutionPolicy, results chan *Iter) *Iter {
	ticker := time.NewTicker(sp.Delay())
	defer ticker.Stop()

	for i := 0; i < sp.Attempts(); i++ {
		select {
		case <-ticker.C:
			// Launch another parallel execution.
			go q.run(ctx, qry, results)
		case <-ctx.Done():
			return &Iter{err: ctx.Err()}
		case iter := <-results:
			return iter
		}
	}

	return nil
}
|
||||||
|
|
||||||
|
// executeQuery runs qry to completion. Non-idempotent queries (or a
// zero-attempt policy) execute synchronously; idempotent ones additionally
// get speculative parallel executions per the query's policy.
func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
	// check if the query is not marked as idempotent, if
	// it is, we force the policy to NonSpeculative
	sp := qry.speculativeExecutionPolicy()
	if !qry.IsIdempotent() || sp.Attempts() == 0 {
		return q.do(qry.Context(), qry), nil
	}

	// Cancel any still-running speculative executions once we return.
	ctx, cancel := context.WithCancel(qry.Context())
	defer cancel()

	results := make(chan *Iter, 1)

	// Launch the main execution
	go q.run(ctx, qry, results)

	// The speculative executions are launched _in addition_ to the main
	// execution, on a timer. So Speculation{2} would make 3 executions running
	// in total.
	if iter := q.speculate(ctx, qry, sp, results); iter != nil {
		return iter, nil
	}

	// All speculative launches are in flight; wait for the first winner.
	select {
	case iter := <-results:
		return iter, nil
	case <-ctx.Done():
		return &Iter{err: ctx.Err()}, nil
	}
}
|
||||||
|
|
||||||
|
// do iterates hosts from the selection policy, attempting the query on each
// usable host and applying the query's retry policy to decide whether to
// retry on the same host, move to the next one, or stop.
func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery) *Iter {
	hostIter := q.policy.Pick(qry)
	selectedHost := hostIter()
	rt := qry.retryPolicy()

	var lastErr error
	var iter *Iter
	for selectedHost != nil {
		host := selectedHost.Info()
		// Skip hosts that are unknown or marked down.
		if host == nil || !host.IsUp() {
			selectedHost = hostIter()
			continue
		}

		pool, ok := q.pool.getPool(host)
		if !ok {
			selectedHost = hostIter()
			continue
		}

		conn := pool.Pick()
		if conn == nil {
			selectedHost = hostIter()
			continue
		}

		iter = q.attemptQuery(ctx, qry, conn)
		iter.host = selectedHost.Info()
		// Update host
		switch iter.err {
		case context.Canceled, context.DeadlineExceeded, ErrNotFound:
			// those errors represents logical errors, they should not count
			// toward removing a node from the pool
			selectedHost.Mark(nil)
			return iter
		default:
			selectedHost.Mark(iter.err)
		}

		// Exit if the query was successful
		// or no retry policy defined or retry attempts were reached
		if iter.err == nil || rt == nil || !rt.Attempt(qry) {
			return iter
		}
		lastErr = iter.err

		// If query is unsuccessful, check the error with RetryPolicy to retry
		switch rt.GetRetryType(iter.err) {
		case Retry:
			// retry on the same host
			continue
		case Rethrow, Ignore:
			return iter
		case RetryNextHost:
			// retry on the next host
			selectedHost = hostIter()
			continue
		default:
			// Undefined? Return nil and error, this will panic in the requester
			return &Iter{err: ErrUnknownRetryType}
		}
	}

	// Host iterator exhausted: surface the last attempt error if any,
	// otherwise report that no connection could be obtained at all.
	if lastErr != nil {
		return &Iter{err: lastErr}
	}

	return &Iter{err: ErrNoConnections}
}
|
||||||
|
|
||||||
|
// run executes qry and delivers the result on results, abandoning the send
// (without blocking or leaking) if the context is cancelled first.
func (q *queryExecutor) run(ctx context.Context, qry ExecutableQuery, results chan<- *Iter) {
	select {
	case results <- q.do(ctx, qry):
	case <-ctx.Done():
	}
}
|
@ -0,0 +1,152 @@
|
|||||||
|
package gocql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ring tracks the known members of the cassandra cluster.
type ring struct {
	// endpoints are the set of endpoints which the driver will attempt to connect
	// to in the case it can not reach any of its hosts. They are also used to boot
	// strap the initial connection.
	endpoints []*HostInfo

	// hosts are the set of all hosts in the cassandra ring that we know of.
	// mu guards hosts and hostList.
	mu    sync.RWMutex
	hosts map[string]*HostInfo

	// hostList mirrors hosts as an ordered slice for round-robin access;
	// pos is the round-robin cursor, advanced atomically.
	hostList []*HostInfo
	pos      uint32

	// TODO: we should store the ring metadata here also.
}
|
||||||
|
|
||||||
|
// rrHost returns the next host in round-robin order, or nil if the ring is
// empty.
func (r *ring) rrHost() *HostInfo {
	// TODO: should we filter hosts that get used here? These hosts will be used
	// for the control connection, should we also provide an iterator?
	r.mu.RLock()
	defer r.mu.RUnlock()
	if len(r.hostList) == 0 {
		return nil
	}

	// Atomic increment so concurrent callers advance the cursor safely.
	pos := int(atomic.AddUint32(&r.pos, 1) - 1)
	return r.hostList[pos%len(r.hostList)]
}
|
||||||
|
|
||||||
|
func (r *ring) getHost(ip net.IP) *HostInfo {
|
||||||
|
r.mu.RLock()
|
||||||
|
host := r.hosts[ip.String()]
|
||||||
|
r.mu.RUnlock()
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ring) allHosts() []*HostInfo {
|
||||||
|
r.mu.RLock()
|
||||||
|
hosts := make([]*HostInfo, 0, len(r.hosts))
|
||||||
|
for _, host := range r.hosts {
|
||||||
|
hosts = append(hosts, host)
|
||||||
|
}
|
||||||
|
r.mu.RUnlock()
|
||||||
|
return hosts
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ring) currentHosts() map[string]*HostInfo {
|
||||||
|
r.mu.RLock()
|
||||||
|
hosts := make(map[string]*HostInfo, len(r.hosts))
|
||||||
|
for k, v := range r.hosts {
|
||||||
|
hosts[k] = v
|
||||||
|
}
|
||||||
|
r.mu.RUnlock()
|
||||||
|
return hosts
|
||||||
|
}
|
||||||
|
|
||||||
|
// addHost inserts or replaces the host keyed by its connect address.
// NOTE: it returns true if the host was ALREADY present (i.e. this was a
// replacement), false if it was newly added — callers rely on this inverted-
// sounding convention.
func (r *ring) addHost(host *HostInfo) bool {
	// TODO(zariel): key all host info by HostID instead of
	// ip addresses
	if host.invalidConnectAddr() {
		panic(fmt.Sprintf("invalid host: %v", host))
	}
	ip := host.ConnectAddress().String()

	r.mu.Lock()
	if r.hosts == nil {
		r.hosts = make(map[string]*HostInfo)
	}

	_, ok := r.hosts[ip]
	if !ok {
		// Keep the ordered list in sync for round-robin iteration.
		r.hostList = append(r.hostList, host)
	}

	r.hosts[ip] = host
	r.mu.Unlock()
	return ok
}
|
||||||
|
|
||||||
|
// addOrUpdate inserts the host if new, or merges its fields into the existing
// entry, returning whichever HostInfo the ring now holds for that address.
func (r *ring) addOrUpdate(host *HostInfo) *HostInfo {
	if existingHost, ok := r.addHostIfMissing(host); ok {
		// Host already known: refresh it in place and hand back the
		// canonical instance.
		existingHost.update(host)
		host = existingHost
	}
	return host
}
|
||||||
|
|
||||||
|
// addHostIfMissing inserts the host only if its connect address is not yet
// known. It returns the HostInfo stored in the ring (the pre-existing one if
// there was a collision) and whether the host already existed.
func (r *ring) addHostIfMissing(host *HostInfo) (*HostInfo, bool) {
	if host.invalidConnectAddr() {
		panic(fmt.Sprintf("invalid host: %v", host))
	}
	ip := host.ConnectAddress().String()

	r.mu.Lock()
	if r.hosts == nil {
		r.hosts = make(map[string]*HostInfo)
	}

	existing, ok := r.hosts[ip]
	if !ok {
		r.hosts[ip] = host
		existing = host
		// Keep the ordered list in sync for round-robin iteration.
		r.hostList = append(r.hostList, host)
	}
	r.mu.Unlock()
	return existing, ok
}
|
||||||
|
|
||||||
|
func (r *ring) removeHost(ip net.IP) bool {
|
||||||
|
r.mu.Lock()
|
||||||
|
if r.hosts == nil {
|
||||||
|
r.hosts = make(map[string]*HostInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
k := ip.String()
|
||||||
|
_, ok := r.hosts[k]
|
||||||
|
if ok {
|
||||||
|
for i, host := range r.hostList {
|
||||||
|
if host.ConnectAddress().Equal(ip) {
|
||||||
|
r.hostList = append(r.hostList[:i], r.hostList[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
delete(r.hosts, k)
|
||||||
|
r.mu.Unlock()
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// clusterMetadata holds cluster-wide facts learned from the control
// connection. mu guards partitioner.
type clusterMetadata struct {
	mu          sync.RWMutex
	partitioner string
}

// setPartitioner records the cluster's partitioner name if it changed.
func (c *clusterMetadata) setPartitioner(partitioner string) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.partitioner != partitioner {
		// TODO: update other things now
		c.partitioner = partitioner
	}
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,223 @@
|
|||||||
|
// Copyright (c) 2015 The gocql Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package gocql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/md5"
|
||||||
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gocql/gocql/internal/murmur"
|
||||||
|
)
|
||||||
|
|
||||||
|
// a token partitioner: maps partition keys and their textual token forms
// onto tokens for a particular Cassandra partitioner.
type partitioner interface {
	Name() string
	Hash([]byte) token
	ParseString(string) token
}

// a token: an orderable position on the ring.
type token interface {
	fmt.Stringer
	Less(token) bool
}
|
||||||
|
|
||||||
|
// murmur3 partitioner and token
type murmur3Partitioner struct{}
type murmur3Token int64

func (p murmur3Partitioner) Name() string {
	return "Murmur3Partitioner"
}

// Hash computes the token from a partition key using the driver's internal
// murmur3 implementation.
func (p murmur3Partitioner) Hash(partitionKey []byte) token {
	// murmur3 little-endian, 128-bit hash, but returns only h1
	h1 := murmur.Murmur3H1(partitionKey)
	return murmur3Token(h1)
}

// ParseString parses the decimal token text; parse errors silently yield
// token 0, matching Cassandra's int64 token range assumption.
func (p murmur3Partitioner) ParseString(str string) token {
	val, _ := strconv.ParseInt(str, 10, 64)
	return murmur3Token(val)
}

func (m murmur3Token) String() string {
	return strconv.FormatInt(int64(m), 10)
}

// Less orders tokens as signed 64-bit integers.
func (m murmur3Token) Less(token token) bool {
	return m < token.(murmur3Token)
}
|
||||||
|
|
||||||
|
// order preserving partitioner and token
|
||||||
|
type orderedPartitioner struct{}
|
||||||
|
type orderedToken string
|
||||||
|
|
||||||
|
func (p orderedPartitioner) Name() string {
|
||||||
|
return "OrderedPartitioner"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p orderedPartitioner) Hash(partitionKey []byte) token {
|
||||||
|
// the partition key is the token
|
||||||
|
return orderedToken(partitionKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p orderedPartitioner) ParseString(str string) token {
|
||||||
|
return orderedToken(str)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o orderedToken) String() string {
|
||||||
|
return string(o)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o orderedToken) Less(token token) bool {
|
||||||
|
return o < token.(orderedToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
// random partitioner and token
|
||||||
|
type randomPartitioner struct{}
|
||||||
|
type randomToken big.Int
|
||||||
|
|
||||||
|
func (r randomPartitioner) Name() string {
|
||||||
|
return "RandomPartitioner"
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2 ** 128
|
||||||
|
var maxHashInt, _ = new(big.Int).SetString("340282366920938463463374607431768211456", 10)
|
||||||
|
|
||||||
|
func (p randomPartitioner) Hash(partitionKey []byte) token {
|
||||||
|
sum := md5.Sum(partitionKey)
|
||||||
|
val := new(big.Int)
|
||||||
|
val.SetBytes(sum[:])
|
||||||
|
if sum[0] > 127 {
|
||||||
|
val.Sub(val, maxHashInt)
|
||||||
|
val.Abs(val)
|
||||||
|
}
|
||||||
|
|
||||||
|
return (*randomToken)(val)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p randomPartitioner) ParseString(str string) token {
|
||||||
|
val := new(big.Int)
|
||||||
|
val.SetString(str, 10)
|
||||||
|
return (*randomToken)(val)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *randomToken) String() string {
|
||||||
|
return (*big.Int)(r).String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *randomToken) Less(token token) bool {
|
||||||
|
return -1 == (*big.Int)(r).Cmp((*big.Int)(token.(*randomToken)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// hostToken pairs a ring token with the host that owns it.
type hostToken struct {
	token token
	host  *HostInfo
}

// String renders the pair for debugging output.
func (ht hostToken) String() string {
	return fmt.Sprintf("{token=%v host=%v}", ht.token, ht.host.HostID())
}
|
||||||
|
|
||||||
|
// a data structure for organizing the relationship between tokens and hosts.
// tokens is kept sorted (see newTokenRing) so owners can be found by binary
// search.
type tokenRing struct {
	partitioner partitioner
	tokens      []hostToken
	hosts       []*HostInfo
}
|
||||||
|
|
||||||
|
// newTokenRing builds a sorted token ring from the hosts' advertised tokens.
// The partitioner is matched by class-name suffix (Cassandra reports the
// fully-qualified Java class); unknown partitioners are an error.
func newTokenRing(partitioner string, hosts []*HostInfo) (*tokenRing, error) {
	tokenRing := &tokenRing{
		hosts: hosts,
	}

	if strings.HasSuffix(partitioner, "Murmur3Partitioner") {
		tokenRing.partitioner = murmur3Partitioner{}
	} else if strings.HasSuffix(partitioner, "OrderedPartitioner") {
		tokenRing.partitioner = orderedPartitioner{}
	} else if strings.HasSuffix(partitioner, "RandomPartitioner") {
		tokenRing.partitioner = randomPartitioner{}
	} else {
		return nil, fmt.Errorf("Unsupported partitioner '%s'", partitioner)
	}

	// Collect every (token, host) pair; a host typically owns many vnode
	// tokens.
	for _, host := range hosts {
		for _, strToken := range host.Tokens() {
			token := tokenRing.partitioner.ParseString(strToken)
			tokenRing.tokens = append(tokenRing.tokens, hostToken{token, host})
		}
	}

	// Sort by token so lookups can binary-search (tokenRing implements
	// sort.Interface over tokens).
	sort.Sort(tokenRing)

	return tokenRing, nil
}
|
||||||
|
|
||||||
|
// Len, Less, and Swap implement sort.Interface over the token list so the
// ring can be ordered by token.
func (t *tokenRing) Len() int {
	return len(t.tokens)
}

func (t *tokenRing) Less(i, j int) bool {
	return t.tokens[i].token.Less(t.tokens[j].token)
}

func (t *tokenRing) Swap(i, j int) {
	t.tokens[i], t.tokens[j] = t.tokens[j], t.tokens[i]
}
|
||||||
|
|
||||||
|
func (t *tokenRing) String() string {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
buf.WriteString("TokenRing(")
|
||||||
|
if t.partitioner != nil {
|
||||||
|
buf.WriteString(t.partitioner.Name())
|
||||||
|
}
|
||||||
|
buf.WriteString("){")
|
||||||
|
sep := ""
|
||||||
|
for i, th := range t.tokens {
|
||||||
|
buf.WriteString(sep)
|
||||||
|
sep = ","
|
||||||
|
buf.WriteString("\n\t[")
|
||||||
|
buf.WriteString(strconv.Itoa(i))
|
||||||
|
buf.WriteString("]")
|
||||||
|
buf.WriteString(th.token.String())
|
||||||
|
buf.WriteString(":")
|
||||||
|
buf.WriteString(th.host.ConnectAddress().String())
|
||||||
|
}
|
||||||
|
buf.WriteString("\n}")
|
||||||
|
return string(buf.Bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHostForPartitionKey hashes the raw partition key with the ring's
// partitioner and returns the primary replica for the resulting token,
// together with the end token of the owning range. A nil ring yields
// (nil, nil).
func (t *tokenRing) GetHostForPartitionKey(partitionKey []byte) (host *HostInfo, endToken token) {
	if t == nil {
		return nil, nil
	}

	return t.GetHostForToken(t.partitioner.Hash(partitionKey))
}
|
||||||
|
|
||||||
|
func (t *tokenRing) GetHostForToken(token token) (host *HostInfo, endToken token) {
|
||||||
|
if t == nil || len(t.tokens) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// find the primary replica
|
||||||
|
p := sort.Search(len(t.tokens), func(i int) bool {
|
||||||
|
return !t.tokens[i].token.Less(token)
|
||||||
|
})
|
||||||
|
|
||||||
|
if p == len(t.tokens) {
|
||||||
|
// wrap around to the first in the ring
|
||||||
|
p = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
v := t.tokens[p]
|
||||||
|
return v.host, v.token
|
||||||
|
}
|
@ -0,0 +1,276 @@
|
|||||||
|
package gocql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// hostTokens maps one ring token to the full ordered replica set for the
// range that token ends.
type hostTokens struct {
	// token is the end token of the range.
	token token
	// hosts is the ordered replica list; the primary replica is first.
	hosts []*HostInfo
}

// tokenRingReplicas is the precomputed replica map for a whole ring, kept
// sorted by token so replicasFor can binary-search.
type tokenRingReplicas []hostTokens

func (h tokenRingReplicas) Less(i, j int) bool { return h[i].token.Less(h[j].token) }

func (h tokenRingReplicas) Len() int { return len(h) }

func (h tokenRingReplicas) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
|
||||||
|
|
||||||
|
// replicasFor returns the replica set for the range that contains token t,
// or nil when the map is empty. h must be sorted by token.
func (h tokenRingReplicas) replicasFor(t token) *hostTokens {
	if len(h) == 0 {
		return nil
	}

	// First entry whose token is >= t.
	p := sort.Search(len(h), func(i int) bool {
		return !h[i].token.Less(t)
	})

	// TODO: simplify this
	if p < len(h) && h[p].token == t {
		// Exact hit: t is itself a ring token.
		return &h[p]
	}

	// Otherwise step back to the preceding entry.
	// NOTE(review): this selects the entry *before* the insertion point;
	// confirm against the partitioner's range convention that this is the
	// intended owner rather than h[p] itself.
	p--

	if p >= len(h) {
		// rollover
		p = 0
	} else if p < 0 {
		// rollunder
		p = len(h) - 1
	}

	return &h[p]
}
|
||||||
|
|
||||||
|
// placementStrategy abstracts a Cassandra replication strategy: it can
// compute the full token->replicas map for a ring and report the
// replication factor configured for a datacenter.
type placementStrategy interface {
	// replicaMap computes the ordered replica set for every token in the ring.
	replicaMap(tokenRing *tokenRing) tokenRingReplicas
	// replicationFactor returns the configured RF for the named datacenter.
	replicationFactor(dc string) int
}
|
||||||
|
|
||||||
|
// getReplicationFactorFromOpts normalizes a replication_factor option value
// (which the schema tables may deliver as an int or a numeric string) into a
// positive int. Invalid or non-positive values panic, naming the keyspace so
// the misconfiguration is attributable.
func getReplicationFactorFromOpts(keyspace string, val interface{}) int {
	// TODO: dont really want to panic here, but is better
	// than spamming
	switch v := val.(type) {
	case int:
		if v <= 0 {
			panic(fmt.Sprintf("invalid replication_factor %d. Is the %q keyspace configured correctly?", v, keyspace))
		}
		return v
	case string:
		n, err := strconv.Atoi(v)
		if err != nil {
			panic(fmt.Sprintf("invalid replication_factor. Is the %q keyspace configured correctly? %v", keyspace, err))
		} else if n <= 0 {
			panic(fmt.Sprintf("invalid replication_factor %d. Is the %q keyspace configured correctly?", n, keyspace))
		}
		return n
	default:
		// "unknown" was previously misspelled "unkown" in this message.
		panic(fmt.Sprintf("unknown replication_factor type %T", v))
	}
}
|
||||||
|
|
||||||
|
func getStrategy(ks *KeyspaceMetadata) placementStrategy {
|
||||||
|
switch {
|
||||||
|
case strings.Contains(ks.StrategyClass, "SimpleStrategy"):
|
||||||
|
return &simpleStrategy{rf: getReplicationFactorFromOpts(ks.Name, ks.StrategyOptions["replication_factor"])}
|
||||||
|
case strings.Contains(ks.StrategyClass, "NetworkTopologyStrategy"):
|
||||||
|
dcs := make(map[string]int)
|
||||||
|
for dc, rf := range ks.StrategyOptions {
|
||||||
|
if dc == "class" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
dcs[dc] = getReplicationFactorFromOpts(ks.Name+":dc="+dc, rf)
|
||||||
|
}
|
||||||
|
return &networkTopology{dcs: dcs}
|
||||||
|
case strings.Contains(ks.StrategyClass, "LocalStrategy"):
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
// TODO: handle unknown replicas and just return the primary host for a token
|
||||||
|
panic(fmt.Sprintf("unsupported strategy class: %v", ks.StrategyClass))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// simpleStrategy implements placementStrategy for SimpleStrategy keyspaces:
// a single cluster-wide replication factor, ignoring datacenters and racks.
type simpleStrategy struct {
	// rf is the cluster-wide replication factor.
	rf int
}

// replicationFactor returns the same RF for every datacenter, since
// SimpleStrategy is not datacenter-aware.
func (s *simpleStrategy) replicationFactor(dc string) int {
	return s.rf
}
|
||||||
|
|
||||||
|
func (s *simpleStrategy) replicaMap(tokenRing *tokenRing) tokenRingReplicas {
|
||||||
|
tokens := tokenRing.tokens
|
||||||
|
ring := make(tokenRingReplicas, len(tokens))
|
||||||
|
|
||||||
|
for i, th := range tokens {
|
||||||
|
replicas := make([]*HostInfo, 0, s.rf)
|
||||||
|
seen := make(map[*HostInfo]bool)
|
||||||
|
|
||||||
|
for j := 0; j < len(tokens) && len(replicas) < s.rf; j++ {
|
||||||
|
h := tokens[(i+j)%len(tokens)]
|
||||||
|
if !seen[h.host] {
|
||||||
|
replicas = append(replicas, h.host)
|
||||||
|
seen[h.host] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ring[i] = hostTokens{th.token, replicas}
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Sort(ring)
|
||||||
|
|
||||||
|
return ring
|
||||||
|
}
|
||||||
|
|
||||||
|
// networkTopology implements placementStrategy for NetworkTopologyStrategy
// keyspaces: a separate replication factor per datacenter, placed rack-aware.
type networkTopology struct {
	// dcs maps datacenter name -> configured replication factor.
	dcs map[string]int
}

// replicationFactor returns the RF configured for dc (zero when the
// datacenter is not present in the keyspace options).
func (n *networkTopology) replicationFactor(dc string) int {
	return n.dcs[dc]
}
|
||||||
|
|
||||||
|
func (n *networkTopology) haveRF(replicaCounts map[string]int) bool {
|
||||||
|
if len(replicaCounts) != len(n.dcs) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for dc, rf := range n.dcs {
|
||||||
|
if rf != replicaCounts[dc] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// replicaMap computes, for every token in the ring, the NetworkTopologyStrategy
// replica set: walking clockwise from the token's owner, it takes one host per
// rack per datacenter until each DC has its configured RF, preferring rack
// diversity and only falling back to already-seen racks ("skipped" hosts)
// once every rack in a DC has contributed. Panics on internal invariant
// violations (replica overflow, empty replica set, wrong primary).
func (n *networkTopology) replicaMap(tokenRing *tokenRing) tokenRingReplicas {
	// dc -> set of racks that exist in that dc (built from the host list).
	dcRacks := make(map[string]map[string]struct{}, len(n.dcs))
	// skipped hosts in a dc
	skipped := make(map[string][]*HostInfo, len(n.dcs))
	// number of replicas per dc
	replicasInDC := make(map[string]int, len(n.dcs))
	// dc -> racks
	seenDCRacks := make(map[string]map[string]struct{}, len(n.dcs))

	// Discover which racks exist in which datacenter.
	for _, h := range tokenRing.hosts {
		dc := h.DataCenter()
		rack := h.Rack()

		racks, ok := dcRacks[dc]
		if !ok {
			racks = make(map[string]struct{})
			dcRacks[dc] = racks
		}
		racks[rack] = struct{}{}
	}

	for dc, racks := range dcRacks {
		replicasInDC[dc] = 0
		seenDCRacks[dc] = make(map[string]struct{}, len(racks))
	}

	tokens := tokenRing.tokens
	replicaRing := make(tokenRingReplicas, len(tokens))

	// Total replicas needed per token across all datacenters.
	var totalRF int
	for _, rf := range n.dcs {
		totalRF += rf
	}

	for i, th := range tokenRing.tokens {
		// Reset per-token scratch state, reusing the backing storage.
		for k, v := range skipped {
			skipped[k] = v[:0]
		}

		for dc := range n.dcs {
			replicasInDC[dc] = 0
			for rack := range seenDCRacks[dc] {
				delete(seenDCRacks[dc], rack)
			}
		}

		replicas := make([]*HostInfo, 0, totalRF)
		// Walk clockwise from token i until every DC has its RF.
		for j := 0; j < len(tokens) && (len(replicas) < totalRF && !n.haveRF(replicasInDC)); j++ {
			// TODO: ensure we dont add the same host twice
			p := i + j
			if p >= len(tokens) {
				p -= len(tokens)
			}
			h := tokens[p].host

			dc := h.DataCenter()
			rack := h.Rack()

			rf, ok := n.dcs[dc]
			if !ok {
				// skip this DC, dont know about it
				continue
			} else if replicasInDC[dc] >= rf {
				if replicasInDC[dc] > rf {
					panic(fmt.Sprintf("replica overflow. rf=%d have=%d in dc %q", rf, replicasInDC[dc], dc))
				}

				// have enough replicas in this DC
				continue
			} else if _, ok := dcRacks[dc][rack]; !ok {
				// dont know about this rack
				continue
			}

			racks := seenDCRacks[dc]
			if _, ok := racks[rack]; ok && len(racks) == len(dcRacks[dc]) {
				// we have been through all the racks and dont have RF yet, add this
				replicas = append(replicas, h)
				replicasInDC[dc]++
			} else if !ok {
				if racks == nil {
					racks = make(map[string]struct{}, 1)
					seenDCRacks[dc] = racks
				}

				// new rack
				racks[rack] = struct{}{}
				replicas = append(replicas, h)
				r := replicasInDC[dc] + 1

				if len(racks) == len(dcRacks[dc]) {
					// if we have been through all the racks, drain the rest of the skipped
					// hosts until we have RF. The next iteration will skip in the block
					// above
					skippedHosts := skipped[dc]
					var k int
					for ; k < len(skippedHosts) && r+k < rf; k++ {
						sh := skippedHosts[k]
						replicas = append(replicas, sh)
					}
					r += k
					skipped[dc] = skippedHosts[k:]
				}
				replicasInDC[dc] = r
			} else {
				// already seen this rack, keep hold of this host incase
				// we dont get enough for rf
				skipped[dc] = append(skipped[dc], h)
			}
		}

		// Sanity checks: a token must have replicas, and the first replica
		// must be the token's primary owner.
		if len(replicas) == 0 {
			panic(fmt.Sprintf("no replicas for token: %v", th.token))
		} else if !replicas[0].Equal(th.host) {
			panic(fmt.Sprintf("first replica is not the primary replica for the token: expected %v got %v", replicas[0].ConnectAddress(), th.host.ConnectAddress()))
		}

		replicaRing[i] = hostTokens{th.token, replicas}
	}

	if len(replicaRing) != len(tokens) {
		panic(fmt.Sprintf("token map different size to token ring: got %d expected %d", len(replicaRing), len(tokens)))
	}

	return replicaRing
}
|
@ -0,0 +1,315 @@
|
|||||||
|
// Copyright (c) 2012 The gocql Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// The uuid package can be used to generate and parse universally unique
|
||||||
|
// identifiers, a standardized format in the form of a 128 bit number.
|
||||||
|
//
|
||||||
|
// http://tools.ietf.org/html/rfc4122
|
||||||
|
package gocql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UUID is a 128-bit universally unique identifier (RFC 4122), stored
// big-endian as 16 raw bytes.
type UUID [16]byte

// hardwareAddr is the 6-byte node identifier used by version-1 UUIDs;
// seeded in init() from a NIC MAC or random bytes.
var hardwareAddr []byte

// clockSeq is the shared version-1 clock sequence, advanced atomically by
// UUIDFromTime; seeded randomly in init().
var clockSeq uint32

// UUID variant values as returned by UUID.Variant (RFC 4122 section 4.1.1).
const (
	VariantNCSCompat = 0
	VariantIETF      = 2
	VariantMicrosoft = 6
	VariantFuture    = 7
)
|
||||||
|
|
||||||
|
// init seeds the node identifier and clock sequence used when generating
// time-based (version 1) UUIDs.
func init() {
	// Prefer the MAC address of the first non-loopback interface that has one.
	if interfaces, err := net.Interfaces(); err == nil {
		for _, i := range interfaces {
			if i.Flags&net.FlagLoopback == 0 && len(i.HardwareAddr) > 0 {
				hardwareAddr = i.HardwareAddr
				break
			}
		}
	}
	if hardwareAddr == nil {
		// If we failed to obtain the MAC address of the current computer,
		// we will use a randomly generated 6 byte sequence instead and set
		// the multicast bit as recommended in RFC 4122.
		hardwareAddr = make([]byte, 6)
		_, err := io.ReadFull(rand.Reader, hardwareAddr)
		if err != nil {
			panic(err)
		}
		hardwareAddr[0] = hardwareAddr[0] | 0x01
	}

	// initialize the clock sequence with a random number
	// NOTE(review): the ReadFull error is deliberately ignored here, so on
	// failure clockSeq starts at zero — confirm best-effort seeding is intended.
	var clockSeqRand [2]byte
	io.ReadFull(rand.Reader, clockSeqRand[:])
	clockSeq = uint32(clockSeqRand[1])<<8 | uint32(clockSeqRand[0])
}
|
||||||
|
|
||||||
|
// ParseUUID parses a 32 digit hexadecimal number (that might contain hypens)
|
||||||
|
// representing an UUID.
|
||||||
|
func ParseUUID(input string) (UUID, error) {
|
||||||
|
var u UUID
|
||||||
|
j := 0
|
||||||
|
for _, r := range input {
|
||||||
|
switch {
|
||||||
|
case r == '-' && j&1 == 0:
|
||||||
|
continue
|
||||||
|
case r >= '0' && r <= '9' && j < 32:
|
||||||
|
u[j/2] |= byte(r-'0') << uint(4-j&1*4)
|
||||||
|
case r >= 'a' && r <= 'f' && j < 32:
|
||||||
|
u[j/2] |= byte(r-'a'+10) << uint(4-j&1*4)
|
||||||
|
case r >= 'A' && r <= 'F' && j < 32:
|
||||||
|
u[j/2] |= byte(r-'A'+10) << uint(4-j&1*4)
|
||||||
|
default:
|
||||||
|
return UUID{}, fmt.Errorf("invalid UUID %q", input)
|
||||||
|
}
|
||||||
|
j += 1
|
||||||
|
}
|
||||||
|
if j != 32 {
|
||||||
|
return UUID{}, fmt.Errorf("invalid UUID %q", input)
|
||||||
|
}
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UUIDFromBytes converts a raw byte slice to an UUID.
|
||||||
|
func UUIDFromBytes(input []byte) (UUID, error) {
|
||||||
|
var u UUID
|
||||||
|
if len(input) != 16 {
|
||||||
|
return u, errors.New("UUIDs must be exactly 16 bytes long")
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(u[:], input)
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RandomUUID generates a totally random UUID (version 4) as described in
|
||||||
|
// RFC 4122.
|
||||||
|
func RandomUUID() (UUID, error) {
|
||||||
|
var u UUID
|
||||||
|
_, err := io.ReadFull(rand.Reader, u[:])
|
||||||
|
if err != nil {
|
||||||
|
return u, err
|
||||||
|
}
|
||||||
|
u[6] &= 0x0F // clear version
|
||||||
|
u[6] |= 0x40 // set version to 4 (random uuid)
|
||||||
|
u[8] &= 0x3F // clear variant
|
||||||
|
u[8] |= 0x80 // set to IETF variant
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// timeBase is the Unix time of the UUID epoch, 15 Oct 1582 (start of the
// Gregorian calendar), used to convert between Unix time and UUID timestamps.
var timeBase = time.Date(1582, time.October, 15, 0, 0, 0, 0, time.UTC).Unix()

// getTimestamp converts time to UUID (version 1) timestamp.
// It must be an interval of 100-nanoseconds since timeBase.
func getTimestamp(t time.Time) int64 {
	utcTime := t.In(time.UTC)
	// Whole seconds since the UUID epoch in 100ns units, plus the
	// sub-second remainder truncated to 100ns resolution.
	ts := int64(utcTime.Unix()-timeBase)*10000000 + int64(utcTime.Nanosecond()/100)

	return ts
}
|
||||||
|
|
||||||
|
// TimeUUID generates a new time based UUID (version 1) using the current
// time as the timestamp.
func TimeUUID() UUID {
	return UUIDFromTime(time.Now())
}
|
||||||
|
|
||||||
|
// The min and max clock values for a UUID.
//
// Cassandra's TimeUUIDType compares the lsb parts as signed byte arrays.
// Thus, the min value for each byte is -128 and the max is +127.
const (
	minClock = 0x8080 // both clock bytes at signed minimum (-128)
	maxClock = 0x7f7f // both clock bytes at signed maximum (+127)
)
|
||||||
|
|
||||||
|
// The min and max node values for a UUID.
//
// See explanation about Cassandra's TimeUUIDType comparison logic above.
var (
	minNode = []byte{0x80, 0x80, 0x80, 0x80, 0x80, 0x80} // every node byte at signed minimum
	maxNode = []byte{0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f} // every node byte at signed maximum
)
|
||||||
|
|
||||||
|
// MinTimeUUID generates a "fake" time based UUID (version 1) which will be
// the smallest possible UUID generated for the provided timestamp.
//
// UUIDs generated by this function are not unique and are mostly suitable only
// in queries to select a time range of a Cassandra's TimeUUID column.
func MinTimeUUID(t time.Time) UUID {
	return TimeUUIDWith(getTimestamp(t), minClock, minNode)
}
|
||||||
|
|
||||||
|
// MaxTimeUUID generates a "fake" time based UUID (version 1) which will be
// the biggest possible UUID generated for the provided timestamp.
//
// UUIDs generated by this function are not unique and are mostly suitable only
// in queries to select a time range of a Cassandra's TimeUUID column.
func MaxTimeUUID(t time.Time) UUID {
	return TimeUUIDWith(getTimestamp(t), maxClock, maxNode)
}
|
||||||
|
|
||||||
|
// UUIDFromTime generates a new time based UUID (version 1) as described in
// RFC 4122. This UUID contains the MAC address of the node that generated
// the UUID, the given timestamp and a sequence number.
func UUIDFromTime(t time.Time) UUID {
	ts := getTimestamp(t)
	// Advance the shared clock sequence atomically so concurrent callers
	// with the same timestamp still get distinct UUIDs.
	clock := atomic.AddUint32(&clockSeq, 1)

	return TimeUUIDWith(ts, clock, hardwareAddr)
}
|
||||||
|
|
||||||
|
// TimeUUIDWith generates a new time based UUID (version 1) as described in
// RFC4122 with given parameters. t is the number of 100's of nanoseconds
// since 15 Oct 1582 (60bits). clock is the number of clock sequence (14bits).
// node is a slice to guarantee the uniqueness of the UUID (up to 6bytes).
// Note: calling this function does not increment the static clock sequence.
func TimeUUIDWith(t int64, clock uint32, node []byte) UUID {
	var u UUID

	// time_low (bytes 0-3): least-significant 32 bits of the timestamp.
	u[0], u[1], u[2], u[3] = byte(t>>24), byte(t>>16), byte(t>>8), byte(t)
	// time_mid (bytes 4-5): middle 16 bits.
	u[4], u[5] = byte(t>>40), byte(t>>32)
	// time_hi (bytes 6-7): top 12 bits; byte 6's high nibble is masked
	// clear so the version bits can be OR'd in below.
	u[6], u[7] = byte(t>>56)&0x0F, byte(t>>48)

	// clock_seq_hi_and_reserved and clock_seq_low (bytes 8-9).
	u[8] = byte(clock >> 8)
	u[9] = byte(clock)

	// node identifier (bytes 10-15).
	copy(u[10:], node)

	u[6] |= 0x10 // set version to 1 (time based uuid)
	u[8] &= 0x3F // clear variant
	u[8] |= 0x80 // set to IETF variant

	return u
}
|
||||||
|
|
||||||
|
// String returns the UUID in it's canonical form, a 32 digit hexadecimal
|
||||||
|
// number in the form of xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx.
|
||||||
|
func (u UUID) String() string {
|
||||||
|
var offsets = [...]int{0, 2, 4, 6, 9, 11, 14, 16, 19, 21, 24, 26, 28, 30, 32, 34}
|
||||||
|
const hexString = "0123456789abcdef"
|
||||||
|
r := make([]byte, 36)
|
||||||
|
for i, b := range u {
|
||||||
|
r[offsets[i]] = hexString[b>>4]
|
||||||
|
r[offsets[i]+1] = hexString[b&0xF]
|
||||||
|
}
|
||||||
|
r[8] = '-'
|
||||||
|
r[13] = '-'
|
||||||
|
r[18] = '-'
|
||||||
|
r[23] = '-'
|
||||||
|
return string(r)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bytes returns the raw byte slice for this UUID. A UUID is always 128 bits
// (16 bytes) long. The slice aliases the (value-copied) receiver.
func (u UUID) Bytes() []byte {
	return u[:]
}
|
||||||
|
|
||||||
|
// Variant returns the variant of this UUID. This package will only generate
|
||||||
|
// UUIDs in the IETF variant.
|
||||||
|
func (u UUID) Variant() int {
|
||||||
|
x := u[8]
|
||||||
|
if x&0x80 == 0 {
|
||||||
|
return VariantNCSCompat
|
||||||
|
}
|
||||||
|
if x&0x40 == 0 {
|
||||||
|
return VariantIETF
|
||||||
|
}
|
||||||
|
if x&0x20 == 0 {
|
||||||
|
return VariantMicrosoft
|
||||||
|
}
|
||||||
|
return VariantFuture
|
||||||
|
}
|
||||||
|
|
||||||
|
// Version extracts the version of this UUID variant. The RFC 4122 describes
// five kinds of UUIDs. The version lives in the high nibble of byte 6.
func (u UUID) Version() int {
	return int(u[6] & 0xF0 >> 4)
}
|
||||||
|
|
||||||
|
// Node extracts the MAC address of the node who generated this UUID. It will
// return nil if the UUID is not a time based UUID (version 1). The returned
// slice aliases the receiver's last 6 bytes.
func (u UUID) Node() []byte {
	if u.Version() != 1 {
		return nil
	}
	return u[10:]
}
|
||||||
|
|
||||||
|
// Clock extracts the clock sequence of this UUID. It will return zero if the
// UUID is not a time based UUID (version 1).
func (u UUID) Clock() uint32 {
	if u.Version() != 1 {
		return 0
	}

	// Clock sequence is the lower 14bits of u[8:10]
	return uint32(u[8]&0x3F)<<8 | uint32(u[9])
}
|
||||||
|
|
||||||
|
// Timestamp extracts the timestamp information from a time based UUID
// (version 1): 100-nanosecond intervals since the UUID epoch (15 Oct 1582).
// It returns 0 for non-version-1 UUIDs.
func (u UUID) Timestamp() int64 {
	if u.Version() != 1 {
		return 0
	}
	// Reassemble the 60-bit timestamp from time_low (bytes 0-3),
	// time_mid (bytes 4-5) and time_hi (bytes 6-7, low nibble of byte 6).
	return int64(uint64(u[0])<<24|uint64(u[1])<<16|
		uint64(u[2])<<8|uint64(u[3])) +
		int64(uint64(u[4])<<40|uint64(u[5])<<32) +
		int64(uint64(u[6]&0x0F)<<56|uint64(u[7])<<48)
}
|
||||||
|
|
||||||
|
// Time is like Timestamp, except that it returns a time.Time.
|
||||||
|
func (u UUID) Time() time.Time {
|
||||||
|
if u.Version() != 1 {
|
||||||
|
return time.Time{}
|
||||||
|
}
|
||||||
|
t := u.Timestamp()
|
||||||
|
sec := t / 1e7
|
||||||
|
nsec := (t % 1e7) * 100
|
||||||
|
return time.Unix(sec+timeBase, nsec).UTC()
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON encodes the UUID as a quoted canonical-form JSON string.
func (u UUID) MarshalJSON() ([]byte, error) {
	return []byte(`"` + u.String() + `"`), nil
}
|
||||||
|
|
||||||
|
// Unmarshaling for JSON
|
||||||
|
func (u *UUID) UnmarshalJSON(data []byte) error {
|
||||||
|
str := strings.Trim(string(data), `"`)
|
||||||
|
if len(str) > 36 {
|
||||||
|
return fmt.Errorf("invalid JSON UUID %s", str)
|
||||||
|
}
|
||||||
|
|
||||||
|
parsed, err := ParseUUID(str)
|
||||||
|
if err == nil {
|
||||||
|
copy(u[:], parsed[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalText encodes the UUID as its canonical-form text representation,
// satisfying encoding.TextMarshaler.
func (u UUID) MarshalText() ([]byte, error) {
	return []byte(u.String()), nil
}
|
||||||
|
|
||||||
|
func (u *UUID) UnmarshalText(text []byte) (err error) {
|
||||||
|
*u, err = ParseUUID(string(text))
|
||||||
|
return
|
||||||
|
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue