Consume GoogleSignIn.validateJWT
TL;DR: - Consume GoogleSignIn.validateJWT in the Handler for /verify - Rename validation fn to validateJWT - Prefer Text to String type
This commit is contained in:
		
							parent
							
								
									8a7a3b29a9
								
							
						
					
					
						commit
						e8f35f0d10
					
				
					 4 changed files with 60 additions and 26 deletions
				
			
		|  | @ -3,7 +3,7 @@ | |||
| module GoogleSignIn where | ||||
| -------------------------------------------------------------------------------- | ||||
| import Data.String.Conversions (cs) | ||||
| import Data.Text (Text) | ||||
| import Data.Text | ||||
| import Web.JWT | ||||
| import Utils | ||||
| 
 | ||||
|  | @ -14,10 +14,16 @@ import qualified Data.Time.Clock.POSIX as POSIX | |||
| -------------------------------------------------------------------------------- | ||||
| 
 | ||||
| newtype EncodedJWT = EncodedJWT Text | ||||
|   deriving (Show) | ||||
| 
 | ||||
| newtype DecodedJWT = DecodedJWT (JWT UnverifiedJWT) | ||||
|   deriving (Show) | ||||
| 
 | ||||
| instance Eq DecodedJWT where | ||||
|   (DecodedJWT _) == (DecodedJWT _) = True | ||||
| 
 | ||||
| -- | Some of the errors that a JWT | ||||
| data ValidationResult | ||||
|   = Valid | ||||
|   = Valid DecodedJWT | ||||
|   | DecodeError | ||||
|   | GoogleSaysInvalid Text | ||||
|   | NoMatchingClientIDs [StringOrURI] | ||||
|  | @ -36,10 +42,10 @@ data ValidationResult | |||
| -- * The `exp` time has not passed | ||||
| -- | ||||
| -- Set `skipHTTP` to `True` to avoid making the network request for testing. | ||||
| jwtIsValid :: Bool | ||||
| validateJWT :: Bool | ||||
|            -> EncodedJWT | ||||
|            -> IO ValidationResult | ||||
| jwtIsValid skipHTTP (EncodedJWT encodedJWT) = do | ||||
| validateJWT skipHTTP (EncodedJWT encodedJWT) = do | ||||
|   case encodedJWT |> decode of | ||||
|     Nothing -> pure DecodeError | ||||
|     Just jwt -> do | ||||
|  | @ -91,4 +97,16 @@ jwtIsValid skipHTTP (EncodedJWT encodedJWT) = do | |||
|                       if not $ currentTime <= jwtExpiry then | ||||
|                         pure $ StaleExpiry jwtExpiry | ||||
|                       else | ||||
|                         pure Valid | ||||
|                         pure $ jwt |> DecodedJWT |> Valid | ||||
| 
 | ||||
| -- | Attempt to explain the `ValidationResult` to a human. | ||||
| explainResult :: ValidationResult -> String | ||||
| explainResult (Valid _) = "Everything appears to be valid" | ||||
| explainResult DecodeError = "We had difficulty decoding the provided JWT" | ||||
| explainResult (GoogleSaysInvalid x) = "After checking with Google, they claimed that the provided JWT was invalid: " ++ cs x | ||||
| explainResult (NoMatchingClientIDs audFields) = "None of the values in the `aud` field on the provided JWT match our client ID: " ++ show audFields | ||||
| explainResult (WrongIssuer issuer) = "The `iss` field in the provided JWT does not match what we expect: " ++ show issuer | ||||
| explainResult (StringOrURIParseFailure x) = "We had difficulty parsing values as URIs" ++ show x | ||||
| explainResult TimeConversionFailure = "We had difficulty converting the current time to a value we can use to compare with the JWT's `exp` field" | ||||
| explainResult (MissingRequiredClaim claim) = "Your JWT is missing the following claim: " ++ cs claim | ||||
| explainResult (StaleExpiry x) = "The `exp` field on your JWT has expired" ++ x |> show |> cs | ||||
|  |  | |||
|  | @ -7,10 +7,14 @@ module Main where | |||
| import Servant | ||||
| import API | ||||
| import Control.Monad.IO.Class (liftIO) | ||||
| import GoogleSignIn (EncodedJWT(..), ValidationResult(..)) | ||||
| import Data.String.Conversions (cs) | ||||
| import Utils | ||||
| 
 | ||||
| import qualified Network.Wai.Handler.Warp as Warp | ||||
| import qualified Network.Wai.Middleware.Cors as Cors | ||||
| import qualified Types as T | ||||
| import qualified GoogleSignIn | ||||
| -------------------------------------------------------------------------------- | ||||
| 
 | ||||
| server :: Server API | ||||
|  | @ -18,8 +22,13 @@ server = verifyGoogleSignIn | |||
|   where | ||||
|     verifyGoogleSignIn :: T.VerifyGoogleSignInRequest -> Handler NoContent | ||||
|     verifyGoogleSignIn T.VerifyGoogleSignInRequest{..} = do | ||||
|       liftIO $ putStrLn $ "Received: " ++ idToken | ||||
|     validationResult <- liftIO $ GoogleSignIn.validateJWT False (EncodedJWT idToken) | ||||
|     case validationResult of | ||||
|       Valid _ -> do | ||||
|         liftIO $ putStrLn "Sign-in valid! Let's create a session" | ||||
|         pure NoContent | ||||
|       err -> do | ||||
|         throwError err401 { errBody = err |> GoogleSignIn.explainResult |> cs } | ||||
| 
 | ||||
| main :: IO () | ||||
| main = do | ||||
|  |  | |||
|  | @ -4,8 +4,8 @@ module Spec where | |||
| -------------------------------------------------------------------------------- | ||||
| import Test.Hspec | ||||
| import Utils | ||||
| import Web.JWT (numericDate) | ||||
| import GoogleSignIn (ValidationResult(..)) | ||||
| import Web.JWT (numericDate, decode) | ||||
| import GoogleSignIn (EncodedJWT(..), DecodedJWT(..), ValidationResult(..)) | ||||
| 
 | ||||
| import qualified GoogleSignIn | ||||
| import qualified Fixtures as F | ||||
|  | @ -16,36 +16,40 @@ import qualified Data.Time.Clock.POSIX as POSIX | |||
| main :: IO () | ||||
| main = hspec $ do | ||||
|   describe "GoogleSignIn" $ | ||||
|     describe "jwtIsValid" $ do | ||||
|       let jwtIsValid' = GoogleSignIn.jwtIsValid True | ||||
|     describe "validateJWT" $ do | ||||
|       let validateJWT' = GoogleSignIn.validateJWT True | ||||
|       it "returns a decode error when an incorrectly encoded JWT is used" $ do | ||||
|         jwtIsValid' (GoogleSignIn.EncodedJWT "rubbish") `shouldReturn` DecodeError | ||||
|         validateJWT' (GoogleSignIn.EncodedJWT "rubbish") `shouldReturn` DecodeError | ||||
| 
 | ||||
|       it "returns validation error when the aud field doesn't match my client ID" $ do | ||||
|         let auds = ["wrong-client-id"] | ||||
|                    |> fmap TestUtils.unsafeStringOrURI | ||||
|             encodedJWT = F.defaultJWTFields { F.overwriteAuds = auds } | ||||
|                          |> F.googleJWT | ||||
|         jwtIsValid' encodedJWT `shouldReturn` NoMatchingClientIDs auds | ||||
|         validateJWT' encodedJWT `shouldReturn` NoMatchingClientIDs auds | ||||
| 
 | ||||
|       it "returns validation success when one of the aud fields matches my client ID" $ do | ||||
|         let auds = ["wrong-client-id", "771151720060-buofllhed98fgt0j22locma05e7rpngl.apps.googleusercontent.com"] | ||||
|                    |> fmap TestUtils.unsafeStringOrURI | ||||
|             encodedJWT = F.defaultJWTFields { F.overwriteAuds = auds } | ||||
|             encodedJWT@(EncodedJWT jwt) = | ||||
|               F.defaultJWTFields { F.overwriteAuds = auds } | ||||
|               |> F.googleJWT | ||||
|         jwtIsValid' encodedJWT `shouldReturn` Valid | ||||
|             decodedJWT = jwt |> decode |> TestUtils.unsafeJust |> DecodedJWT | ||||
|         validateJWT' encodedJWT `shouldReturn` Valid decodedJWT | ||||
| 
 | ||||
|       it "returns validation error when one of the iss field doesn't match accounts.google.com or https://accounts.google.com" $ do | ||||
|         let erroneousIssuer = TestUtils.unsafeStringOrURI "not-accounts.google.com" | ||||
|             encodedJWT = F.defaultJWTFields { F.overwriteIss = erroneousIssuer } | ||||
|                          |> F.googleJWT | ||||
|         jwtIsValid' encodedJWT `shouldReturn` WrongIssuer erroneousIssuer | ||||
|         validateJWT' encodedJWT `shouldReturn` WrongIssuer erroneousIssuer | ||||
| 
 | ||||
|       it "returns validation success when the iss field matches accounts.google.com or https://accounts.google.com" $ do | ||||
|         let erroneousIssuer = TestUtils.unsafeStringOrURI "https://accounts.google.com" | ||||
|             encodedJWT = F.defaultJWTFields { F.overwriteIss = erroneousIssuer } | ||||
|             encodedJWT@(EncodedJWT jwt) = | ||||
|               F.defaultJWTFields { F.overwriteIss = erroneousIssuer } | ||||
|               |> F.googleJWT | ||||
|         jwtIsValid' encodedJWT `shouldReturn` Valid | ||||
|             decodedJWT = jwt |> decode |> TestUtils.unsafeJust |> DecodedJWT | ||||
|         validateJWT' encodedJWT `shouldReturn` Valid decodedJWT | ||||
| 
 | ||||
|       it "fails validation when the exp field has expired" $ do | ||||
|         let mErroneousExp = numericDate 0 | ||||
|  | @ -54,7 +58,7 @@ main = hspec $ do | |||
|           Just erroneousExp -> do | ||||
|             let encodedJWT = F.defaultJWTFields { F.overwriteExp = erroneousExp } | ||||
|                              |> F.googleJWT | ||||
|             jwtIsValid' encodedJWT `shouldReturn` StaleExpiry erroneousExp | ||||
|             validateJWT' encodedJWT `shouldReturn` StaleExpiry erroneousExp | ||||
| 
 | ||||
|       it "passes validation when the exp field is current" $ do | ||||
|         mFreshExp <- POSIX.getPOSIXTime | ||||
|  | @ -63,6 +67,8 @@ main = hspec $ do | |||
|         case mFreshExp of | ||||
|           Nothing -> True `shouldBe` False | ||||
|           Just freshExp -> do | ||||
|             let encodedJWT = F.defaultJWTFields { F.overwriteExp = freshExp } | ||||
|             let encodedJWT@(EncodedJWT jwt) = | ||||
|                   F.defaultJWTFields { F.overwriteExp = freshExp } | ||||
|                   |> F.googleJWT | ||||
|             jwtIsValid' encodedJWT `shouldReturn` Valid | ||||
|                 decodedJWT = jwt |> decode |> TestUtils.unsafeJust |> DecodedJWT | ||||
|             validateJWT' encodedJWT `shouldReturn` Valid decodedJWT | ||||
|  |  | |||
|  | @ -4,10 +4,11 @@ | |||
| module Types where | ||||
| -------------------------------------------------------------------------------- | ||||
| import Data.Aeson | ||||
| import Data.Text | ||||
| -------------------------------------------------------------------------------- | ||||
| 
 | ||||
| data VerifyGoogleSignInRequest = VerifyGoogleSignInRequest | ||||
|   { idToken :: String | ||||
|   { idToken :: Text | ||||
|   } deriving (Eq, Show) | ||||
| 
 | ||||
| instance FromJSON VerifyGoogleSignInRequest where | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue