{-# LANGUAGE CPP               #-}
{-# LANGUAGE OverloadedStrings #-}
------------------------------------------------------------------------------
-- | This is a support module meant to back all session back-end
-- implementations.
--
-- It gives us an encrypted and timestamped cookie that can store an arbitrary
-- serializable payload. For security, it will:
--
--   * Encrypt its payload together with a timestamp.
--
--   * Check the timestamp for session expiration every time you read from
--     the cookie. This limits intercept-and-replay attacks by rejecting
--     cookies older than the timeout threshold.

module Snap.Snaplet.Session.SecureCookie
       ( SecureCookie
       , getSecureCookie
       , setSecureCookie
       , expireSecureCookie
       -- ** Helper functions
       , encodeSecureCookie
       , decodeSecureCookie
       , checkTimeout
       ) where

------------------------------------------------------------------------------
import           Control.Monad
import           Control.Monad.Trans
import           Data.ByteString       (ByteString)
import           Data.Serialize
import           Data.Time
import           Data.Time.Clock.POSIX
import           Snap.Core
import           Web.ClientSession

#if !MIN_VERSION_base(4,8,0)
import           Control.Applicative
#endif

------------------------------------------------------------------------------
-- | An arbitrary payload paired with the timestamp stored alongside it.
--
-- 'setSecureCookie' stamps the payload with the current time. Once it has
-- round-tripped through 'encodeSecureCookie' / 'decodeSecureCookie', the
-- timestamp has whole-second precision, because it is stored as rounded
-- POSIX seconds.
type SecureCookie t = (UTCTime, t)


------------------------------------------------------------------------------
-- | Read and decrypt the named cookie, returning its payload if it is
-- still fresh.
--
-- A cookie already queued on the outgoing response takes precedence over
-- the one sent by the client. A value written earlier in the same request
-- is therefore visible here.
--
-- 'Nothing' is returned in any of these cases:
--
-- * no such cookie exists;
--
-- * its value fails to decrypt or deserialize;
--
-- * 'checkTimeout' reports that its timestamp has expired.
getSecureCookie :: (MonadSnap m, Serialize t)
                => ByteString       -- ^ Cookie name
                -> Key              -- ^ Encryption key
                -> Maybe Int        -- ^ Timeout in seconds; 'Nothing'
                                    --   disables the age check
                -> m (Maybe t)
getSecureCookie name key timeout = do
    fromRequest  <- getCookie name
    fromResponse <- fmap (getResponseCookie name) getResponse
    -- Prefer the response cookie: it reflects writes made during this
    -- request.
    let raw = cookieValue <$> (fromResponse `mplus` fromRequest)
    maybe (return Nothing) freshPayload (raw >>= decodeSecureCookie key)
  where
    -- Keep the payload only if its timestamp is within the timeout.
    freshPayload (stamp, payload) = do
        expired <- checkTimeout timeout stamp
        return (if expired then Nothing else Just payload)


------------------------------------------------------------------------------
-- | Decrypt a cookie value with the given key and deserialize it into a
-- timestamped payload.
--
-- The plaintext must be a serialized @(Integer, a)@ pair. The 'Integer' is
-- the timestamp in whole POSIX seconds, which is the layout
-- 'encodeSecureCookie' writes.
--
-- Returns 'Nothing' if decryption fails or the plaintext does not
-- deserialize.
decodeSecureCookie  :: Serialize a
                     => Key                     -- ^ Encryption key
                     -> ByteString              -- ^ Encrypted payload
                     -> Maybe (SecureCookie a)
decodeSecureCookie key value =
    case decrypt key value of
      Nothing        -> Nothing
      Just plaintext ->
        case decode plaintext of
          Left _                -> Nothing
          Right (secs, payload) ->
            Just (posixSecondsToUTCTime (fromInteger secs), payload)


------------------------------------------------------------------------------
-- | Stamp a payload with the current time, encrypt it, and queue it on the
-- response as a cookie.
--
-- The cookie is scoped to path @\/@ and the given domain. It is marked
-- HttpOnly but not Secure.
--
-- If a max age is supplied, the cookie's expiry is set that many seconds
-- from now. Otherwise no expiry is set.
setSecureCookie :: (MonadSnap m, Serialize t)
                => ByteString       -- ^ Cookie name
                -> Maybe ByteString -- ^ Cookie domain
                -> Key              -- ^ Encryption key
                -> Maybe Int        -- ^ Max age in seconds
                -> t                -- ^ Serializable payload
                -> m ()
setSecureCookie name domain key to val = do
    now       <- liftIO getCurrentTime
    encrypted <- encodeSecureCookie key (now, val)
    let expiresAt = fmap (\secs -> addUTCTime (fromIntegral secs) now) to
        cookie    = Cookie name encrypted expiresAt domain (Just "/") False True
    modifyResponse (addResponseCookie cookie)


------------------------------------------------------------------------------
-- | Serialize a timestamped payload and encrypt it with the given key,
-- producing a value that can be stored in a cookie.
--
-- The timestamp is stored as POSIX seconds converted with 'round', so
-- sub-second precision is lost.
--
-- Encryption goes through 'encryptIO', which runs in 'IO'; hence the
-- 'MonadIO' constraint.
encodeSecureCookie :: (MonadIO m, Serialize t)
                    => Key            -- ^ Encryption key
                    -> SecureCookie t -- ^ Payload
                    -> m ByteString
encodeSecureCookie key (stamp, payload) = liftIO (encryptIO key plaintext)
  where
    wholeSeconds :: Integer
    wholeSeconds = round (utcTimeToPOSIXSeconds stamp)

    plaintext = encode (wholeSeconds, payload)


------------------------------------------------------------------------------
-- | Ask the client to discard the named secure cookie.
--
-- Delegates to 'expireCookie' with an empty-valued cookie on path @\/@ for
-- the given domain. That path matches the one 'setSecureCookie' uses. The
-- domain presumably needs to match the one passed when the cookie was set;
-- verify against callers.
expireSecureCookie :: MonadSnap m
                   => ByteString       -- ^ Cookie name
                   -> Maybe ByteString -- ^ Cookie domain
                   -> m ()
expireSecureCookie name domain =
    expireCookie (Cookie name "" Nothing domain (Just "/") False False)


------------------------------------------------------------------------------
-- | Validate a session timestamp against the timeout policy.
--
-- * If the timeout is 'Nothing', never trigger a time-out.
--
-- * If it is @'Just' secs@, trigger a time-out ('True') once the current
--   time is strictly later than the timestamp plus @secs@ seconds.
--
-- Only the current time is read, so any 'MonadIO' suffices. Every
-- 'MonadSnap' is also a 'MonadIO', so existing callers keep working.
checkTimeout :: MonadIO m
             => Maybe Int  -- ^ Timeout in seconds ('Nothing' = no timeout)
             -> UTCTime    -- ^ Timestamp to validate
             -> m Bool     -- ^ 'True' if the timestamp has timed out
checkTimeout Nothing   _  = return False
checkTimeout (Just x) t0 = do
    t1 <- liftIO getCurrentTime
    return $ t1 > addUTCTime (fromIntegral x) t0