summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authordiogob <>2020-10-16 21:27:00 (GMT)
committerhdiff <hdiff@hdiff.luite.com>2020-10-16 21:27:00 (GMT)
commit7a5fc3b475c3b5140ddf557e55bd4c2684982d61 (patch)
tree41e9437de7bda28d056fd68de097a0573e9311ec
parentba72ec86a4644be9126cc7fbaed95a4149f1f874 (diff)
version 0.10.0.0HEAD0.10.0.0master
-rw-r--r--postgres-websockets.cabal3
-rw-r--r--src/PostgresWebsockets.hs6
-rw-r--r--src/PostgresWebsockets/Claims.hs11
-rw-r--r--src/PostgresWebsockets/Config.hs2
-rw-r--r--src/PostgresWebsockets/Context.hs39
-rw-r--r--src/PostgresWebsockets/HasqlBroadcast.hs12
-rw-r--r--src/PostgresWebsockets/Middleware.hs110
-rw-r--r--src/PostgresWebsockets/Server.hs36
-rw-r--r--test/ServerSpec.hs44
9 files changed, 153 insertions, 110 deletions
diff --git a/postgres-websockets.cabal b/postgres-websockets.cabal
index f0c5d12..9a854ec 100644
--- a/postgres-websockets.cabal
+++ b/postgres-websockets.cabal
@@ -1,5 +1,5 @@
name: postgres-websockets
-version: 0.9.0.0
+version: 0.10.0.0
synopsis: Middleware to map LISTEN/NOTIFY messages to Websockets
description: Please see README.md
homepage: https://github.com/diogob/postgres-websockets#readme
@@ -25,6 +25,7 @@ library
other-modules: Paths_postgres_websockets
, PostgresWebsockets.Server
, PostgresWebsockets.Middleware
+ , PostgresWebsockets.Context
build-depends: base >= 4.7 && < 5
, hasql-pool >= 0.5 && < 0.6
, text >= 1.2 && < 1.3
diff --git a/src/PostgresWebsockets.hs b/src/PostgresWebsockets.hs
index 31af03b..f42e887 100644
--- a/src/PostgresWebsockets.hs
+++ b/src/PostgresWebsockets.hs
@@ -11,6 +11,6 @@ module PostgresWebsockets
, postgresWsMiddleware
) where
-import PostgresWebsockets.Middleware
-import PostgresWebsockets.Server
-import PostgresWebsockets.Config
+import PostgresWebsockets.Middleware ( postgresWsMiddleware )
+import PostgresWebsockets.Server ( serve )
+import PostgresWebsockets.Config ( prettyVersion, loadConfig )
diff --git a/src/PostgresWebsockets/Claims.hs b/src/PostgresWebsockets/Claims.hs
index 62b0fd6..2e7f0b3 100644
--- a/src/PostgresWebsockets/Claims.hs
+++ b/src/PostgresWebsockets/Claims.hs
@@ -11,15 +11,16 @@ module PostgresWebsockets.Claims
( ConnectionInfo,validateClaims
) where
-import Control.Lens
-import qualified Crypto.JOSE.Types as JOSE.Types
-import Crypto.JWT
-import qualified Data.HashMap.Strict as M
-import Protolude
+import Protolude
+import Control.Lens
+import Crypto.JWT
import Data.List
import Data.Time.Clock (UTCTime)
+import qualified Crypto.JOSE.Types as JOSE.Types
+import qualified Data.HashMap.Strict as M
import qualified Data.Aeson as JSON
+
type Claims = M.HashMap Text JSON.Value
type ConnectionInfo = ([ByteString], ByteString, Claims)
diff --git a/src/PostgresWebsockets/Config.hs b/src/PostgresWebsockets/Config.hs
index 66df5d6..e693d35 100644
--- a/src/PostgresWebsockets/Config.hs
+++ b/src/PostgresWebsockets/Config.hs
@@ -31,6 +31,7 @@ data AppConfig = AppConfig {
, configHost :: Text
, configPort :: Int
, configListenChannel :: Text
+ , configMetaChannel :: Maybe Text
, configJwtSecret :: ByteString
, configJwtSecretIsBase64 :: Bool
, configPool :: Int
@@ -68,6 +69,7 @@ readOptions =
<*> var str "PGWS_HOST" (def "*4" <> helpDef show <> help "Address the server will listen for websocket connections")
<*> var auto "PGWS_PORT" (def 3000 <> helpDef show <> help "Port the server will listen for websocket connections")
<*> var str "PGWS_LISTEN_CHANNEL" (def "postgres-websockets-listener" <> helpDef show <> help "Master channel used in the database to send or read messages in any notification channel")
+ <*> optional (var str "PGWS_META_CHANNEL" (help "Websockets channel used to send events about the server state changes."))
<*> var str "PGWS_JWT_SECRET" (help "Secret used to sign JWT tokens used to open communications channels")
<*> var auto "PGWS_JWT_SECRET_BASE64" (def False <> helpDef show <> help "Indicate whether the JWT secret should be decoded from a base64 encoded string")
<*> var auto "PGWS_POOL_SIZE" (def 10 <> helpDef show <> help "How many connection to the database should be used by the connection pool")
diff --git a/src/PostgresWebsockets/Context.hs b/src/PostgresWebsockets/Context.hs
new file mode 100644
index 0000000..89e352c
--- /dev/null
+++ b/src/PostgresWebsockets/Context.hs
@@ -0,0 +1,39 @@
+{-|
+Module : PostgresWebsockets.Context
+Description : Produce a context capable of running postgres-websockets sessions
+-}
+module PostgresWebsockets.Context
+ ( Context (..)
+ , mkContext
+ ) where
+
+import Protolude
+import Data.Time.Clock (UTCTime, getCurrentTime)
+import Control.AutoUpdate ( defaultUpdateSettings
+ , mkAutoUpdate
+ , updateAction
+ )
+import qualified Hasql.Pool as P
+
+import PostgresWebsockets.Config ( AppConfig(..) )
+import PostgresWebsockets.HasqlBroadcast (newHasqlBroadcaster)
+import PostgresWebsockets.Broadcast (Multiplexer)
+
+data Context = Context {
+ ctxConfig :: AppConfig
+ , ctxPool :: P.Pool
+ , ctxMulti :: Multiplexer
+ , ctxGetTime :: IO UTCTime
+ }
+
+-- | Given a configuration and a shutdown action (performed when the Multiplexer's listen connection dies) produces the context necessary to run sessions
+mkContext :: AppConfig -> IO () -> IO Context
+mkContext conf@AppConfig{..} shutdown = do
+ Context conf
+ <$> P.acquire (configPool, 10, pgSettings)
+ <*> newHasqlBroadcaster shutdown (toS configListenChannel) configRetries pgSettings
+ <*> mkGetTime
+ where
+ mkGetTime :: IO (IO UTCTime)
+ mkGetTime = mkAutoUpdate defaultUpdateSettings {updateAction = getCurrentTime}
+ pgSettings = toS configDatabase
diff --git a/src/PostgresWebsockets/HasqlBroadcast.hs b/src/PostgresWebsockets/HasqlBroadcast.hs
index efb34ba..275ea8f 100644
--- a/src/PostgresWebsockets/HasqlBroadcast.hs
+++ b/src/PostgresWebsockets/HasqlBroadcast.hs
@@ -19,11 +19,11 @@ import Protolude hiding (putErrLn)
import Hasql.Connection
import Hasql.Notifications
-import Data.Aeson (decode, Value(..))
-import Data.HashMap.Lazy (lookupDefault)
+import Data.Aeson (decode, Value(..))
+import Data.HashMap.Lazy (lookupDefault)
import Data.Either.Combinators (mapBoth)
-import Data.Function (id)
-import Control.Retry (RetryStatus(..), retrying, capDelay, exponentialBackoff)
+import Data.Function (id)
+import Control.Retry (RetryStatus(..), retrying, capDelay, exponentialBackoff)
import PostgresWebsockets.Broadcast
@@ -99,11 +99,11 @@ newHasqlBroadcasterForChannel onConnectionFailure ch getCon = do
_ -> d
lookupStringDef _ d _ = d
channelDef = lookupStringDef "channel"
- openProducer msgs = do
+ openProducer msgQ = do
con <- getCon
listen con $ toPgIdentifier ch
waitForNotifications
- (\c m-> atomically $ writeTQueue msgs $ toMsg c m)
+ (\c m-> atomically $ writeTQueue msgQ $ toMsg c m)
con
putErrLn :: Text -> IO ()
diff --git a/src/PostgresWebsockets/Middleware.hs b/src/PostgresWebsockets/Middleware.hs
index 8444b27..f030183 100644
--- a/src/PostgresWebsockets/Middleware.hs
+++ b/src/PostgresWebsockets/Middleware.hs
@@ -10,39 +10,48 @@ module PostgresWebsockets.Middleware
( postgresWsMiddleware
) where
-import qualified Hasql.Pool as H
-import qualified Hasql.Notifications as H
-import qualified Network.Wai as Wai
+import Protolude
+import Data.Time.Clock (UTCTime)
+import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds, posixSecondsToUTCTime)
+import Control.Concurrent.AlarmClock (newAlarmClock, setAlarm)
+import qualified Hasql.Notifications as H
+import qualified Hasql.Pool as H
+import qualified Network.Wai as Wai
import qualified Network.Wai.Handler.WebSockets as WS
-import qualified Network.WebSockets as WS
-import Protolude
-
-import qualified Data.Aeson as A
-import qualified Data.ByteString.Char8 as BS
-import qualified Data.ByteString.Lazy as BL
-import qualified Data.HashMap.Strict as M
-import qualified Data.Text.Encoding.Error as T
-import Data.Time.Clock (UTCTime)
-import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds, posixSecondsToUTCTime)
-import Control.Concurrent.AlarmClock (newAlarmClock, setAlarm)
-import PostgresWebsockets.Broadcast (Multiplexer, onMessage)
-import qualified PostgresWebsockets.Broadcast as B
-import PostgresWebsockets.Claims
+import qualified Network.WebSockets as WS
+
+import qualified Data.Aeson as A
+import qualified Data.ByteString.Char8 as BS
+import qualified Data.ByteString.Lazy as BL
+import qualified Data.HashMap.Strict as M
+import qualified Data.Text.Encoding.Error as T
+
+import PostgresWebsockets.Broadcast (onMessage)
+import PostgresWebsockets.Claims ( ConnectionInfo, validateClaims )
+import PostgresWebsockets.Context ( Context(..) )
+import PostgresWebsockets.Config (AppConfig(..))
+import qualified PostgresWebsockets.Broadcast as B
+
+
+data Event =
+ WebsocketMessage
+ | ConnectionOpen
+ deriving (Show, Eq, Generic)
data Message = Message
{ claims :: A.Object
- , channel :: Text
+ , event :: Event
, payload :: Text
+ , channel :: Text
} deriving (Show, Eq, Generic)
+instance A.ToJSON Event
instance A.ToJSON Message
-- | Given a secret, a function to fetch the system time, a Hasql Pool and a Multiplexer this will give you a WAI middleware.
-postgresWsMiddleware :: IO UTCTime -> Text -> ByteString -> H.Pool -> Multiplexer -> Wai.Application -> Wai.Application
+postgresWsMiddleware :: Context -> Wai.Middleware
postgresWsMiddleware =
- WS.websocketsOr WS.defaultConnectionOptions `compose` wsApp
- where
- compose = (.) . (.) . (.) . (.) . (.)
+ WS.websocketsOr WS.defaultConnectionOptions . wsApp
-- private functions
jwtExpirationStatusCode :: Word16
@@ -50,9 +59,9 @@ jwtExpirationStatusCode = 3001
-- when the websocket is closed a ConnectionClosed Exception is triggered
-- this kills all children and frees resources for us
-wsApp :: IO UTCTime -> Text -> ByteString -> H.Pool -> Multiplexer -> WS.ServerApp
-wsApp getTime dbChannel secret pool multi pendingConn =
- getTime >>= validateClaims requestChannel secret (toS jwtToken) >>= either rejectRequest forkSessions
+wsApp :: Context -> WS.ServerApp
+wsApp Context{..} pendingConn =
+ ctxGetTime >>= validateClaims requestChannel (configJwtSecret ctxConfig) (toS jwtToken) >>= either rejectRequest forkSessions
where
hasRead m = m == ("r" :: ByteString) || m == ("rw" :: ByteString)
hasWrite m = m == ("w" :: ByteString) || m == ("rw" :: ByteString)
@@ -85,12 +94,21 @@ wsApp getTime dbChannel secret pool multi pendingConn =
Just _ -> pure ()
Nothing -> pure ()
+ let sendNotification msg channel = sendMessageWithTimestamp $ websocketMessageForChannel msg channel
+ sendMessageToDatabase = sendToDatabase ctxPool (configListenChannel ctxConfig)
+ sendMessageWithTimestamp = timestampMessage ctxGetTime >=> sendMessageToDatabase
+ websocketMessageForChannel = Message validClaims WebsocketMessage
+ connectionOpenMessage = Message validClaims ConnectionOpen
+
+ case configMetaChannel ctxConfig of
+ Nothing -> pure ()
+ Just ch -> sendMessageWithTimestamp $ connectionOpenMessage (toS $ BS.intercalate "," chs) ch
+
when (hasRead mode) $
- forM_ chs $ flip (onMessage multi) $ WS.sendTextData conn . B.payload
+ forM_ chs $ flip (onMessage ctxMulti) $ WS.sendTextData conn . B.payload
when (hasWrite mode) $
- let sendNotifications = void . H.notifyPool pool dbChannel . toS
- in notifySession validClaims conn getTime sendNotifications chs
+ notifySession conn sendNotification chs
waitForever <- newEmptyMVar
void $ takeMVar waitForever
@@ -98,30 +116,22 @@ wsApp getTime dbChannel secret pool multi pendingConn =
-- Having both channel and claims as parameters seem redundant
-- But it allows the function to ignore the claims structure and the source
-- of the channel, so all claims decoding can be coded in the caller
-notifySession :: A.Object
- -> WS.Connection
- -> IO UTCTime
- -> (ByteString -> IO ())
- -> [ByteString]
- -> IO ()
-notifySession claimsToSend wsCon getTime send chs =
+notifySession :: WS.Connection -> (Text -> Text -> IO ()) -> [ByteString] -> IO ()
+notifySession wsCon sendToChannel chs =
withAsync (forever relayData) wait
where
- relayData = do
+ relayData = do
msg <- WS.receiveData wsCon
- forM_ chs (relayChannelData msg . toS)
+ forM_ chs (sendToChannel msg . toS)
- relayChannelData msg ch = do
- claims' <- claimsWithTime ch
- send $ jsonMsg ch claims' msg
-
- -- we need to decode the bytestring to re-encode valid JSON for the notification
- jsonMsg :: Text -> M.HashMap Text A.Value -> ByteString -> ByteString
- jsonMsg ch cl = BL.toStrict . A.encode . Message cl ch . decodeUtf8With T.lenientDecode
-
- claimsWithTime :: Text -> IO (M.HashMap Text A.Value)
- claimsWithTime ch = do
- time <- utcTimeToPOSIXSeconds <$> getTime
- return $ M.insert "message_delivered_at" (A.Number $ realToFrac time) (claimsWithChannel ch)
+sendToDatabase :: H.Pool -> Text -> Message -> IO ()
+sendToDatabase pool dbChannel =
+ notify . jsonMsg
+ where
+ notify = void . H.notifyPool pool dbChannel . toS
+ jsonMsg = BL.toStrict . A.encode
- claimsWithChannel ch = M.insert "channel" (A.String ch) claimsToSend
+timestampMessage :: IO UTCTime -> Message -> IO Message
+timestampMessage getTime msg@Message{..} = do
+ time <- utcTimeToPOSIXSeconds <$> getTime
+ return $ msg{ claims = M.insert "message_delivered_at" (A.Number $ realToFrac time) claims}
diff --git a/src/PostgresWebsockets/Server.hs b/src/PostgresWebsockets/Server.hs
index b998d34..97b9945 100644
--- a/src/PostgresWebsockets/Server.hs
+++ b/src/PostgresWebsockets/Server.hs
@@ -6,47 +6,35 @@ module PostgresWebsockets.Server
( serve
) where
-import Protolude
-import PostgresWebsockets.Middleware
-import PostgresWebsockets.Config
-import PostgresWebsockets.HasqlBroadcast (newHasqlBroadcaster)
+import Protolude
+import Network.Wai.Application.Static ( staticApp, defaultFileServerSettings )
+import Network.Wai (Application, responseLBS)
+import Network.HTTP.Types (status200)
+import Network.Wai.Handler.Warp ( runSettings )
+import Network.Wai.Middleware.RequestLogger (logStdout)
-import qualified Hasql.Pool as P
-import Network.Wai.Application.Static
-import Data.Time.Clock (UTCTime, getCurrentTime)
-import Control.AutoUpdate ( defaultUpdateSettings
- , mkAutoUpdate
- , updateAction
- )
-import Network.Wai (Application, responseLBS)
-import Network.HTTP.Types (status200)
-import Network.Wai.Handler.Warp
-import Network.Wai.Middleware.RequestLogger (logStdout)
+import PostgresWebsockets.Middleware ( postgresWsMiddleware )
+import PostgresWebsockets.Config ( AppConfig(..), warpSettings )
+import PostgresWebsockets.Context ( mkContext )
-- | Start a stand-alone warp server using the parameters from AppConfig and a opening a database connection pool.
serve :: AppConfig -> IO ()
serve conf@AppConfig{..} = do
shutdownSignal <- newEmptyMVar
- let listenChannel = toS configListenChannel
- pgSettings = toS configDatabase
- waitForShutdown cl = void $ forkIO (takeMVar shutdownSignal >> cl)
+ let waitForShutdown cl = void $ forkIO (takeMVar shutdownSignal >> cl)
appSettings = warpSettings waitForShutdown conf
putStrLn $ ("Listening on port " :: Text) <> show configPort
let shutdown = putErrLn ("Broadcaster connection is dead" :: Text) >> putMVar shutdownSignal ()
- pool <- P.acquire (configPool, 10, pgSettings)
- multi <- newHasqlBroadcaster shutdown listenChannel configRetries pgSettings
- getTime <- mkGetTime
+ ctx <- mkContext conf shutdown
runSettings appSettings $
- postgresWsMiddleware getTime listenChannel configJwtSecret pool multi $
+ postgresWsMiddleware ctx $
logStdout $ maybe dummyApp staticApp' configPath
die "Shutting down server..."
where
- mkGetTime :: IO (IO UTCTime)
- mkGetTime = mkAutoUpdate defaultUpdateSettings {updateAction = getCurrentTime}
staticApp' :: Text -> Application
staticApp' = staticApp . defaultFileServerSettings . toS
dummyApp :: Application
diff --git a/test/ServerSpec.hs b/test/ServerSpec.hs
index 9c3fe5c..023dbe7 100644
--- a/test/ServerSpec.hs
+++ b/test/ServerSpec.hs
@@ -13,36 +13,38 @@ import qualified Network.WebSockets as WS
import Network.Socket (withSocketsDo)
testServerConfig :: AppConfig
-testServerConfig = AppConfig
+testServerConfig = AppConfig
{ configDatabase = "postgres://localhost/postgres"
, configPath = Nothing
, configHost = "*"
, configPort = 8080
, configListenChannel = "postgres-websockets-test-channel"
, configJwtSecret = "reallyreallyreallyreallyverysafe"
+ , configMetaChannel = Nothing
, configJwtSecretIsBase64 = False
, configPool = 10
+ , configRetries = 5
}
startTestServer :: IO ThreadId
startTestServer = do
threadId <- forkIO $ serve testServerConfig
- threadDelay 1000
+ threadDelay 500000
pure threadId
withServer :: IO () -> IO ()
withServer action =
bracket startTestServer
- killThread
+ (\tid -> killThread tid >> threadDelay 500000)
(const action)
sendWsData :: Text -> Text -> IO ()
sendWsData uri msg =
- withSocketsDo $
- WS.runClient
- "localhost"
- (configPort testServerConfig)
- (toS uri)
+ withSocketsDo $
+ WS.runClient
+ "localhost"
+ (configPort testServerConfig)
+ (toS uri)
(`WS.sendTextData` msg)
testChannel :: Text
@@ -58,27 +60,27 @@ waitForWsData :: Text -> IO (MVar ByteString)
waitForWsData uri = do
msg <- newEmptyMVar
void $ forkIO $
- withSocketsDo $
- WS.runClient
- "localhost"
- (configPort testServerConfig)
- (toS uri)
+ withSocketsDo $
+ WS.runClient
+ "localhost"
+ (configPort testServerConfig)
+ (toS uri)
(\c -> do
m <- WS.receiveData c
putMVar msg m
)
- threadDelay 1000
+ threadDelay 10000
pure msg
waitForMultipleWsData :: Int -> Text -> IO (MVar [ByteString])
waitForMultipleWsData messageCount uri = do
msg <- newEmptyMVar
void $ forkIO $
- withSocketsDo $
- WS.runClient
- "localhost"
- (configPort testServerConfig)
- (toS uri)
+ withSocketsDo $
+ WS.runClient
+ "localhost"
+ (configPort testServerConfig)
+ (toS uri)
(\c -> do
m <- replicateM messageCount (WS.receiveData c)
putMVar msg m
@@ -112,6 +114,6 @@ spec = around_ withServer $
sendWsData testAndSecondaryChannel "test data"
msgsJson <- takeMVar msgs
- forM_
- msgsJson
+ forM_
+ msgsJson
(\msgJson -> (msgJson ^? key "payload" . _String) `shouldBe` Just "test data")