diff --git a/Makefile b/Makefile index ddf538cd..4fc4dd22 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ SHARE_PROJECT_ROOT := $(dir $(abspath $(lastword $(MAKEFILE_LIST)))) export SHARE_PROJECT_ROOT UNAME := $(shell uname) -STACK_FLAGS := "--fast" +STACK_FLAGS := --fast dist_dir := $(shell stack path | awk '/^dist-dir/{print $$2}') exe_name := share-api exe := $(dist_dir)/build/$(exe_name)/$(exe_name) diff --git a/package.yaml b/package.yaml index f1b5caf2..9d2b58c0 100644 --- a/package.yaml +++ b/package.yaml @@ -97,6 +97,7 @@ dependencies: - parallel - parser-combinators - pem +- profunctors - hasql - hasql-pool - hasql-interpolate @@ -113,6 +114,7 @@ dependencies: - servant-client-core - servant-server - servant-conduit +- servant-websockets - serialise - stm - stm-chans @@ -153,6 +155,7 @@ dependencies: - wai-cors - wai-extra - wai-middleware-prometheus +- websockets - warp - witch - witherable diff --git a/share-api.cabal b/share-api.cabal index 099034bb..17fc896a 100644 --- a/share-api.cabal +++ b/share-api.cabal @@ -1,6 +1,6 @@ cabal-version: 1.12 --- This file has been generated from package.yaml by hpack version 0.37.0. +-- This file has been generated from package.yaml by hpack version 0.38.1. -- -- see: https://github.com/sol/hpack @@ -196,10 +196,15 @@ library Share.Web.UCM.Sync.HashJWT Share.Web.UCM.Sync.Impl Share.Web.UCM.Sync.Types + Share.Web.UCM.SyncCommon.Impl + Share.Web.UCM.SyncCommon.Types Share.Web.UCM.SyncV2.API Share.Web.UCM.SyncV2.Impl Share.Web.UCM.SyncV2.Queries Share.Web.UCM.SyncV2.Types + Share.Web.UCM.SyncV3.API + Share.Web.UCM.SyncV3.Impl + Share.Web.UCM.SyncV3.Queries Share.Web.UI.Links Unison.Server.NameSearch.Postgres Unison.Server.Share.Definitions @@ -301,6 +306,7 @@ library , parallel , parser-combinators , pem + , profunctors , prometheus-client , prometheus-metrics-ghc , random @@ -314,6 +320,7 @@ library , servant-client-core , servant-conduit , servant-server + , servant-websockets , share-auth , share-utils , stm @@ -354,6 +361,7 @@ library , wai-extra , wai-middleware-prometheus , warp + , websockets , witch , witherable , x509 @@ -458,6 +466,7 @@ executable share-api , parallel , parser-combinators , pem + , profunctors , prometheus-client , prometheus-metrics-ghc , random @@ -471,6 +480,7 @@ executable share-api , servant-client-core , servant-conduit , servant-server + , servant-websockets , share-api , share-auth , share-utils @@ -512,6 +522,7 @@ executable share-api , wai-extra , wai-middleware-prometheus , warp + , websockets , witch , witherable , x509 diff --git a/src/Share/Postgres/Orphans.hs b/src/Share/Postgres/Orphans.hs index d9e38bd0..f0871c37 100644 --- a/src/Share/Postgres/Orphans.hs +++ b/src/Share/Postgres/Orphans.hs @@ -38,6 +38,7 @@ import Unison.Hash32 qualified as Hash32 import Unison.Name (Name) import Unison.NameSegment.Internal (NameSegment (..)) import Unison.SyncV2.Types (CBORBytes (..)) +import Unison.SyncV3.Types qualified as SyncV3 import Unison.Syntax.Name qualified as Name import UnliftIO (MonadUnliftIO (..)) @@ -286,3 +287,21 @@ instance MonadUnliftIO Hasql.Session where case res of Left e -> throwError e Right a -> pure a + +instance Hasql.DecodeValue SyncV3.EntityKind where + decodeValue = do + Decoders.enum \case + "causal" -> Just SyncV3.CausalEntity + "namespace" -> Just SyncV3.NamespaceEntity + "component" -> Just SyncV3.DefnComponentEntity + "patch" -> Just SyncV3.PatchEntity + _ -> Nothing + +instance Hasql.EncodeValue SyncV3.EntityKind where + encodeValue = Encoders.enum \case + SyncV3.CausalEntity -> "causal" + SyncV3.NamespaceEntity -> "namespace" + SyncV3.DefnComponentEntity -> "component" + SyncV3.PatchEntity -> "patch" + +deriving newtype instance Hasql.DecodeValue SyncV3.EntityDepth diff --git a/src/Share/Web/API.hs b/src/Share/Web/API.hs index b2f94dae..0be30f11 100644 --- a/src/Share/Web/API.hs +++ b/src/Share/Web/API.hs @@ -17,6 +17,7 @@ import Share.Web.Share.Users.API qualified as Users import Share.Web.Support.API qualified as Support import Share.Web.Types import Share.Web.UCM.SyncV2.API qualified as SyncV2 +import Share.Web.UCM.SyncV3.API qualified as SyncV3 import Unison.Share.API.Projects qualified as UCMProjects import Unison.Sync.API qualified as Unison.Sync @@ -54,6 +55,7 @@ type API = :<|> ("ucm" :> "v1" :> "sync" :> MaybeAuthenticatedSession :> Unison.Sync.API) :<|> ("ucm" :> "v1" :> "projects" :> MaybeAuthenticatedSession :> UCMProjects.ProjectsAPI) :<|> ("ucm" :> "v2" :> "sync" :> MaybeAuthenticatedUserId :> SyncV2.API) + :<|> ("ucm" :> "v3" :> "sync" :> MaybeAuthenticatedUserId :> SyncV3.API) :<|> ("admin" :> Admin.API) api :: Proxy API diff --git a/src/Share/Web/Impl.hs b/src/Share/Web/Impl.hs index 9d8156be..c65ad452 100644 --- a/src/Share/Web/Impl.hs +++ b/src/Share/Web/Impl.hs @@ -29,6 +29,7 @@ import Share.Web.Types import Share.Web.UCM.Projects.Impl qualified as UCMProjects import Share.Web.UCM.Sync.Impl qualified as Sync import Share.Web.UCM.SyncV2.Impl qualified as SyncV2 +import Share.Web.UCM.SyncV3.Impl qualified as SyncV3 import Share.Web.UI.Links qualified as Links discoveryEndpoint :: WebApp DiscoveryDocument @@ -90,4 +91,5 @@ server = :<|> Sync.server :<|> UCMProjects.server :<|> SyncV2.server + :<|> SyncV3.server :<|> Admin.server diff --git a/src/Share/Web/UCM/SyncCommon/Impl.hs b/src/Share/Web/UCM/SyncCommon/Impl.hs new file mode 100644 index 00000000..4d49ad2f --- /dev/null +++ b/src/Share/Web/UCM/SyncCommon/Impl.hs @@ -0,0 +1,55 @@ +module Share.Web.UCM.SyncCommon.Impl + ( parseBranchRef, + codebaseForBranchRef, + ) +where + +import Control.Monad.Except (ExceptT (ExceptT)) +import Servant +import Share.Codebase qualified as Codebase +import Share.IDs (ProjectBranchShortHand (..), ProjectReleaseShortHand (..), ProjectShortHand (..)) +import Share.IDs qualified as IDs +import Share.Postgres qualified as PG +import Share.Postgres.Queries qualified as PGQ +import Share.Postgres.Users.Queries qualified as UserQ +import Share.Prelude +import Share.Project (Project (..)) +import Share.User (User (..)) +import Share.Web.App +import Share.Web.Authorization qualified as AuthZ +import Share.Web.UCM.SyncCommon.Types +import U.Codebase.Sqlite.Orphans () +import Unison.SyncV2.Types qualified as SyncV2 + +parseBranchRef :: SyncV2.BranchRef -> Either Text (Either ProjectReleaseShortHand ProjectBranchShortHand) +parseBranchRef (SyncV2.BranchRef branchRef) = + case parseRelease <|> parseBranch of + Just a -> Right a + Nothing -> Left $ "Invalid repo info: " <> branchRef + where + parseBranch :: Maybe (Either ProjectReleaseShortHand ProjectBranchShortHand) + parseBranch = fmap Right . eitherToMaybe $ IDs.fromText @ProjectBranchShortHand branchRef + parseRelease :: Maybe (Either ProjectReleaseShortHand ProjectBranchShortHand) + parseRelease = fmap Left . eitherToMaybe $ IDs.fromText @ProjectReleaseShortHand branchRef + +codebaseForBranchRef :: SyncV2.BranchRef -> (ExceptT CodebaseLoadingError WebApp Codebase.CodebaseEnv) +codebaseForBranchRef branchRef = do + case parseBranchRef branchRef of + Left err -> throwError (CodebaseLoadingErrorInvalidBranchRef err branchRef) + Right (Left (ProjectReleaseShortHand {userHandle, projectSlug})) -> do + let projectShortHand = ProjectShortHand {userHandle, projectSlug} + (Project {ownerUserId = projectOwnerUserId}, contributorId) <- ExceptT . PG.tryRunTransaction $ do + project <- PGQ.projectByShortHand projectShortHand `whenNothingM` throwError (CodebaseLoadingErrorProjectNotFound $ projectShortHand) + pure (project, Nothing) + authZToken <- lift AuthZ.checkDownloadFromProjectBranchCodebase `whenLeftM` \_err -> throwError (CodebaseLoadingErrorNoReadPermission branchRef) + let codebaseLoc = Codebase.codebaseLocationForProjectBranchCodebase projectOwnerUserId contributorId + pure $ Codebase.codebaseEnv authZToken codebaseLoc + Right (Right (ProjectBranchShortHand {userHandle, projectSlug, contributorHandle})) -> do + let projectShortHand = ProjectShortHand {userHandle, projectSlug} + (Project {ownerUserId = projectOwnerUserId}, contributorId) <- ExceptT . PG.tryRunTransaction $ do + project <- (PGQ.projectByShortHand projectShortHand) `whenNothingM` throwError (CodebaseLoadingErrorProjectNotFound projectShortHand) + mayContributorUserId <- for contributorHandle \ch -> fmap user_id $ (UserQ.userByHandle ch) `whenNothingM` throwError (CodebaseLoadingErrorUserNotFound ch) + pure (project, mayContributorUserId) + authZToken <- lift AuthZ.checkDownloadFromProjectBranchCodebase `whenLeftM` \_err -> throwError (CodebaseLoadingErrorNoReadPermission branchRef) + let codebaseLoc = Codebase.codebaseLocationForProjectBranchCodebase projectOwnerUserId contributorId + pure $ Codebase.codebaseEnv authZToken codebaseLoc diff --git a/src/Share/Web/UCM/SyncCommon/Types.hs b/src/Share/Web/UCM/SyncCommon/Types.hs new file mode 100644 index 00000000..d73d104c --- /dev/null +++ b/src/Share/Web/UCM/SyncCommon/Types.hs @@ -0,0 +1,27 @@ +{-# LANGUAGE DataKinds #-} + +module Share.Web.UCM.SyncCommon.Types (CodebaseLoadingError (..)) where + +import Data.Text.Encoding qualified as Text +import Servant +import Share.IDs +import Share.IDs qualified as IDs +import Share.Prelude +import Share.Utils.Logging qualified as Logging +import Share.Web.Errors +import Unison.SyncCommon.Types + +data CodebaseLoadingError + = CodebaseLoadingErrorProjectNotFound ProjectShortHand + | CodebaseLoadingErrorUserNotFound UserHandle + | CodebaseLoadingErrorNoReadPermission BranchRef + | CodebaseLoadingErrorInvalidBranchRef Text BranchRef + deriving stock (Show) + deriving (Logging.Loggable) via Logging.ShowLoggable Logging.UserFault CodebaseLoadingError + +instance ToServerError CodebaseLoadingError where + toServerError = \case + CodebaseLoadingErrorProjectNotFound projectShortHand -> (ErrorID "codebase-loading:project-not-found", Servant.err404 {errBody = from . Text.encodeUtf8 $ "Project not found: " <> (IDs.toText projectShortHand)}) + CodebaseLoadingErrorUserNotFound userHandle -> (ErrorID "codebase-loading:user-not-found", Servant.err404 {errBody = from . Text.encodeUtf8 $ "User not found: " <> (IDs.toText userHandle)}) + CodebaseLoadingErrorNoReadPermission branchRef -> (ErrorID "codebase-loading:no-read-permission", Servant.err403 {errBody = from . Text.encodeUtf8 $ "No read permission for branch ref: " <> (unBranchRef branchRef)}) + CodebaseLoadingErrorInvalidBranchRef err branchRef -> (ErrorID "codebase-loading:invalid-branch-ref", Servant.err400 {errBody = from . Text.encodeUtf8 $ "Invalid branch ref: " <> err <> " " <> (unBranchRef branchRef)}) diff --git a/src/Share/Web/UCM/SyncV2/Impl.hs b/src/Share/Web/UCM/SyncV2/Impl.hs index c73bdbb1..050b07b5 100644 --- a/src/Share/Web/UCM/SyncV2/Impl.hs +++ b/src/Share/Web/UCM/SyncV2/Impl.hs @@ -7,34 +7,29 @@ import Codec.Serialise qualified as CBOR import Conduit qualified as C import Control.Concurrent.STM qualified as STM import Control.Concurrent.STM.TBMQueue qualified as STM -import Control.Monad.Except (ExceptT (ExceptT), withExceptT) +import Control.Monad.Except (withExceptT) import Control.Monad.Trans.Except (runExceptT) import Data.Binary.Builder qualified as Builder import Data.Set qualified as Set -import Data.Text.Encoding qualified as Text import Data.Vector qualified as Vector import Ki.Unlifted qualified as Ki import Servant import Servant.Conduit (ConduitToSourceIO (..)) import Servant.Types.SourceT (SourceT (..)) import Servant.Types.SourceT qualified as SourceT -import Share.Codebase qualified as Codebase -import Share.IDs (ProjectBranchShortHand (..), ProjectReleaseShortHand (..), ProjectShortHand (..), UserHandle, UserId) +import Share.IDs (UserId) import Share.IDs qualified as IDs import Share.Postgres qualified as PG import Share.Postgres.Causal.Queries qualified as CausalQ import Share.Postgres.Cursors qualified as Cursor -import Share.Postgres.Queries qualified as PGQ -import Share.Postgres.Users.Queries qualified as UserQ import Share.Prelude -import Share.Project (Project (..)) -import Share.User (User (..)) import Share.Utils.Logging qualified as Logging import Share.Utils.Unison (hash32ToCausalHash) import Share.Web.App -import Share.Web.Authorization qualified as AuthZ import Share.Web.Errors import Share.Web.UCM.Sync.HashJWT qualified as HashJWT +import Share.Web.UCM.SyncCommon.Impl +import Share.Web.UCM.SyncCommon.Types import Share.Web.UCM.SyncV2.Queries qualified as SSQ import Share.Web.UCM.SyncV2.Types (IsCausalSpine (..), IsLibRoot (..)) import U.Codebase.Sqlite.Orphans () @@ -58,17 +53,6 @@ server mayUserId = causalDependenciesStream = causalDependenciesStreamImpl mayUserId } -parseBranchRef :: SyncV2.BranchRef -> Either Text (Either ProjectReleaseShortHand ProjectBranchShortHand) -parseBranchRef (SyncV2.BranchRef branchRef) = - case parseRelease <|> parseBranch of - Just a -> Right a - Nothing -> Left $ "Invalid repo info: " <> branchRef - where - parseBranch :: Maybe (Either ProjectReleaseShortHand ProjectBranchShortHand) - parseBranch = fmap Right . eitherToMaybe $ IDs.fromText @ProjectBranchShortHand branchRef - parseRelease :: Maybe (Either ProjectReleaseShortHand ProjectBranchShortHand) - parseRelease = fmap Left . eitherToMaybe $ IDs.fromText @ProjectReleaseShortHand branchRef - downloadEntitiesStreamImpl :: Maybe UserId -> SyncV2.DownloadEntitiesRequest -> WebApp (SourceIO (SyncV2.CBORStream SyncV2.DownloadEntitiesChunk)) downloadEntitiesStreamImpl mayCallerUserId (SyncV2.DownloadEntitiesRequest {causalHash = causalHashJWT, branchRef, knownHashes}) = do either emitErr id <$> runExceptT do @@ -142,43 +126,6 @@ queueToStream q = do loop loop -data CodebaseLoadingError - = CodebaseLoadingErrorProjectNotFound ProjectShortHand - | CodebaseLoadingErrorUserNotFound UserHandle - | CodebaseLoadingErrorNoReadPermission SyncV2.BranchRef - | CodebaseLoadingErrorInvalidBranchRef Text SyncV2.BranchRef - deriving stock (Show) - deriving (Logging.Loggable) via Logging.ShowLoggable Logging.UserFault CodebaseLoadingError - -instance ToServerError CodebaseLoadingError where - toServerError = \case - CodebaseLoadingErrorProjectNotFound projectShortHand -> (ErrorID "codebase-loading:project-not-found", Servant.err404 {errBody = from . Text.encodeUtf8 $ "Project not found: " <> (IDs.toText projectShortHand)}) - CodebaseLoadingErrorUserNotFound userHandle -> (ErrorID "codebase-loading:user-not-found", Servant.err404 {errBody = from . Text.encodeUtf8 $ "User not found: " <> (IDs.toText userHandle)}) - CodebaseLoadingErrorNoReadPermission branchRef -> (ErrorID "codebase-loading:no-read-permission", Servant.err403 {errBody = from . Text.encodeUtf8 $ "No read permission for branch ref: " <> (SyncV2.unBranchRef branchRef)}) - CodebaseLoadingErrorInvalidBranchRef err branchRef -> (ErrorID "codebase-loading:invalid-branch-ref", Servant.err400 {errBody = from . Text.encodeUtf8 $ "Invalid branch ref: " <> err <> " " <> (SyncV2.unBranchRef branchRef)}) - -codebaseForBranchRef :: SyncV2.BranchRef -> (ExceptT CodebaseLoadingError WebApp Codebase.CodebaseEnv) -codebaseForBranchRef branchRef = do - case parseBranchRef branchRef of - Left err -> throwError (CodebaseLoadingErrorInvalidBranchRef err branchRef) - Right (Left (ProjectReleaseShortHand {userHandle, projectSlug})) -> do - let projectShortHand = ProjectShortHand {userHandle, projectSlug} - (Project {ownerUserId = projectOwnerUserId}, contributorId) <- ExceptT . PG.tryRunTransaction $ do - project <- PGQ.projectByShortHand projectShortHand `whenNothingM` throwError (CodebaseLoadingErrorProjectNotFound $ projectShortHand) - pure (project, Nothing) - authZToken <- lift AuthZ.checkDownloadFromProjectBranchCodebase `whenLeftM` \_err -> throwError (CodebaseLoadingErrorNoReadPermission branchRef) - let codebaseLoc = Codebase.codebaseLocationForProjectBranchCodebase projectOwnerUserId contributorId - pure $ Codebase.codebaseEnv authZToken codebaseLoc - Right (Right (ProjectBranchShortHand {userHandle, projectSlug, contributorHandle})) -> do - let projectShortHand = ProjectShortHand {userHandle, projectSlug} - (Project {ownerUserId = projectOwnerUserId}, contributorId) <- ExceptT . PG.tryRunTransaction $ do - project <- (PGQ.projectByShortHand projectShortHand) `whenNothingM` throwError (CodebaseLoadingErrorProjectNotFound projectShortHand) - mayContributorUserId <- for contributorHandle \ch -> fmap user_id $ (UserQ.userByHandle ch) `whenNothingM` throwError (CodebaseLoadingErrorUserNotFound ch) - pure (project, mayContributorUserId) - authZToken <- lift AuthZ.checkDownloadFromProjectBranchCodebase `whenLeftM` \_err -> throwError (CodebaseLoadingErrorNoReadPermission branchRef) - let codebaseLoc = Codebase.codebaseLocationForProjectBranchCodebase projectOwnerUserId contributorId - pure $ Codebase.codebaseEnv authZToken codebaseLoc - -- | Run an IO action in the background while streaming the results. -- -- Servant doesn't provide any easier way to do bracketing like this, all the IO must be diff --git a/src/Share/Web/UCM/SyncV3/API.hs b/src/Share/Web/UCM/SyncV3/API.hs new file mode 100644 index 00000000..f3109188 --- /dev/null +++ b/src/Share/Web/UCM/SyncV3/API.hs @@ -0,0 +1,21 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE TypeOperators #-} + +module Share.Web.UCM.SyncV3.API + ( API, + Routes (..), + ) +where + +import GHC.Generics (Generic) +import Servant +import Servant.API.WebSocket (WebSocket) + +data Routes mode = Routes + { downloadEntities :: mode :- "download" :> DownloadEntitiesEndpoint + } + deriving stock (Generic) + +type API = NamedRoutes Routes + +type DownloadEntitiesEndpoint = WebSocket diff --git a/src/Share/Web/UCM/SyncV3/Impl.hs b/src/Share/Web/UCM/SyncV3/Impl.hs new file mode 100644 index 00000000..5daf05ee --- /dev/null +++ b/src/Share/Web/UCM/SyncV3/Impl.hs @@ -0,0 +1,226 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE TypeOperators #-} + +module Share.Web.UCM.SyncV3.Impl (server) where + +import Control.Lens hiding ((.=)) +import Control.Monad.Cont (ContT (..), MonadCont (..)) +import Control.Monad.Except (runExceptT) +import Data.Set qualified as Set +import Data.Set.Lens (setOf) +import Data.Vector (Vector) +import GHC.Natural +import Ki.Unlifted qualified as Ki +import Network.WebSockets qualified as WS +import Share.Codebase (CodebaseEnv) +import Share.IDs (UserId) +import Share.Postgres qualified as PG +import Share.Prelude +import Share.Web.App +import Share.Web.Authentication.Types qualified as AuthN +import Share.Web.Authorization qualified as AuthZ +import Share.Web.UCM.Sync.HashJWT qualified as HashJWT +import Share.Web.UCM.SyncCommon.Impl (codebaseForBranchRef) +import Share.Web.UCM.SyncCommon.Types +import Share.Web.UCM.SyncV3.API qualified as SyncV3 +import Share.Web.UCM.SyncV3.Queries qualified as Q +import U.Codebase.Sqlite.Orphans () +import Unison.Debug qualified as Debug +import Unison.Hash32 (Hash32) +import Unison.Share.API.Hash (HashJWT, HashJWTClaims (..)) +import Unison.Share.API.Hash qualified as HashJWT +import Unison.SyncV3.Types +import Unison.SyncV3.Utils (entityDependencies) +import Unison.Util.Websockets (Queues (..), withQueues) +import UnliftIO qualified +import UnliftIO.STM + +-- Amount of entities to buffer from the network into the send/recv queues. +sendBufferSize :: Natural +sendBufferSize = 100 + +recvBufferSize :: Natural +recvBufferSize = 100 + +-- data StreamInitInfo = StreamInitInfo + +-- streamSettings :: StreamInitInfo +-- streamSettings = StreamInitInfo + +server :: Maybe UserId -> SyncV3.Routes WebAppServer +server mayUserId = + SyncV3.Routes + { downloadEntities = downloadEntitiesImpl mayUserId + } + +type SyncM = ContT (Either SyncError ()) WebApp + +downloadEntitiesImpl :: Maybe UserId -> WS.Connection -> WebApp () +downloadEntitiesImpl mayCallerUserId conn = do + Debug.debugLogM Debug.Temp "Got connection" + -- Auth is currently done via HashJWTs + _authZReceipt <- AuthZ.checkDownloadFromUserCodebase + doSyncEmitter mayCallerUserId conn + +-- | Given a helper which understands how to wire things into its backend, This +-- implements the sync emitter logic which is independent of the backend. +doSyncEmitter :: + Maybe UserId -> + WS.Connection -> + WebApp () +doSyncEmitter mayCallerUserId conn = do + withQueues @(MsgOrError SyncError (FromEmitterMessage Hash32 Text)) @(MsgOrError SyncError (FromReceiverMessage HashJWT Hash32)) + recvBufferSize + sendBufferSize + conn + \(q@Queues {receive}) -> do + handleErr q $ do + withErrorCont \onErr -> do + Debug.debugLogM Debug.Temp "Got queues" + let recvM :: SyncM (FromReceiverMessage HashJWT Hash32) + recvM = do + result <- liftIO $ atomically receive + Debug.debugM Debug.Temp "Received: " result + case result of + Msg msg -> pure msg + Err err -> onErr err + + Debug.debugLogM Debug.Temp "Waiting for init message" + initMsg <- recvM + Debug.debugM Debug.Temp "Got init: " initMsg + syncState <- case initMsg of + ReceiverInitStream initMsg -> initialize onErr mayCallerUserId initMsg + other -> onErr $ InitializationError ("Expected ReceiverInitStream message, got: " <> tShow other) + Debug.debugLogM Debug.Temp "Initialized sync state, starting sync process." + lift (shareEmitter syncState q) + >>= maybe (pure ()) (onErr) + where + -- Given a continuation-based action, run it in the base monad, capturing any early exits + withErrorCont :: + ((forall x. SyncError -> SyncM x) -> SyncM ()) -> + WebApp (Either SyncError ()) + withErrorCont action = do + flip runContT pure $ callCC \cc -> do + Right <$> action (fmap absurd . cc . Left) + -- If we get an error, send it to the client then shut down. + handleErr :: + (Show err) => + Queues (MsgOrError err a) o -> + WebApp (Either err ()) -> + WebApp () + handleErr (Queues {send, shutdown}) action = do + action >>= \case + Left err -> do + Debug.debugM Debug.Temp "Sync error, shutting down: " err + atomically $ do + send (Err err) + liftIO $ shutdown + Right r -> pure r + +initialize :: (forall x. SyncError -> SyncM x) -> (Maybe UserId) -> InitMsg HashJWT -> SyncM (SyncState sh Hash32) +initialize onErr caller InitMsg {initMsgRootCausal, initMsgBranchRef} = do + let decoded = HashJWT.decodeHashJWT initMsgRootCausal + Debug.debugM Debug.Temp "Decoded root causal hash jwt" decoded + Debug.debugM Debug.Temp "Caller: " caller + HashJWTClaims {hash = initialCausalHash} <- + lift (HashJWT.verifyHashJWT caller initMsgRootCausal) >>= \case + Right ch -> pure ch + Left err -> onErr $ HashJWTVerificationError (AuthN.authErrMsg err) + validRequestsVar <- newTVarIO (Set.singleton (CausalEntity, initialCausalHash)) + requestedEntitiesVar <- newTVarIO (Set.singleton (CausalEntity, initialCausalHash)) + entitiesAlreadySentVar <- newTVarIO Set.empty + (lift . runExceptT $ codebaseForBranchRef initMsgBranchRef) >>= \case + Left err -> case err of + CodebaseLoadingErrorNoReadPermission {} -> onErr $ NoReadPermission initMsgBranchRef + CodebaseLoadingErrorProjectNotFound {} -> onErr $ ProjectNotFound initMsgBranchRef + CodebaseLoadingErrorUserNotFound {} -> onErr $ UserNotFound initMsgBranchRef + CodebaseLoadingErrorInvalidBranchRef msg _ -> onErr $ InvalidBranchRef msg initMsgBranchRef + Right codebase -> + pure $ + SyncState + { codebase, + validRequestsVar, + requestedEntitiesVar, + entitiesAlreadySentVar + } + +data SyncState sh hash = SyncState + { codebase :: CodebaseEnv, + -- To avoid needing to sign HashJWTs for every hash we can just keep track of which hashes we've referenced and check + -- against this set when receiving requests. + validRequestsVar :: TVar (Set (EntityKind, hash)), + -- Entities which have been requested by the client but not yet sent. + requestedEntitiesVar :: TVar (Set (EntityKind, hash)), + -- Hashes which have been sent to the client + entitiesAlreadySentVar :: TVar (Set (EntityKind, hash)) + -- Hash mappings we've already sent to the client. + -- mappedHashesVar :: TVar (Map sh hash) + } + +shareEmitter :: + (SyncState HashTag Hash32) -> + Queues (MsgOrError SyncError (FromEmitterMessage Hash32 Text)) (MsgOrError SyncError (FromReceiverMessage HashJWT Hash32)) -> + WebApp (Maybe SyncError) +shareEmitter SyncState {requestedEntitiesVar, entitiesAlreadySentVar, validRequestsVar, codebase} (Queues {send, receive}) = do + Ki.scoped $ \scope -> do + errVar <- newEmptyTMVarIO + let onErrSTM :: SyncError -> STM () + onErrSTM e = do + UnliftIO.putTMVar errVar e + + Debug.debugLogM Debug.Temp "Launching workers" + Ki.fork scope $ sendWorker onErrSTM + Ki.fork scope $ receiveWorker onErrSTM + Debug.debugLogM Debug.Temp "Waiting on errors or completion..." + atomically ((Ki.awaitAll scope $> Nothing) <|> (Just <$> UnliftIO.takeTMVar errVar)) + where + sendWorker :: (SyncError -> STM ()) -> WebApp () + sendWorker onErrSTM = forever $ do + (validRequests, reqs) <- atomically $ do + reqs <- readTVar requestedEntitiesVar + writeTVar requestedEntitiesVar Set.empty + alreadySent <- readTVar entitiesAlreadySentVar + Debug.debugM Debug.Temp "Processing Requested entities: " reqs + validRequests <- readTVar validRequestsVar + pure (validRequests, reqs `Set.difference` alreadySent) + let forbiddenRequests = Set.difference reqs validRequests + validatedRequests <- + if not (Set.null forbiddenRequests) + then do + atomically (onErrSTM (ForbiddenEntityRequest forbiddenRequests)) + pure $ Set.difference validRequests forbiddenRequests + else do + pure validRequests + Debug.debugM Debug.Temp "Validated requests: " validatedRequests + when (not $ Set.null validatedRequests) $ do + Debug.debugM Debug.Temp "Fetching Entities." validatedRequests + newEntities <- fetchEntities codebase validatedRequests + Debug.debugM Debug.Temp "Fetched entities: " (length newEntities) + -- Do work outside of transactions to avoid conflicts + deps <- UnliftIO.evaluate $ foldMap entityDependencies newEntities + Debug.debugLogM Debug.Temp "Adding new valid requests" + atomically $ modifyTVar' validRequestsVar (\s -> Set.union s deps) + Debug.debugM Debug.Temp "Sending entities: " (length newEntities) + atomically $ do + let newHashes = setOf (folded . to (entityKind &&& entityHash)) newEntities + modifyTVar' entitiesAlreadySentVar (Set.union newHashes) + for_ newEntities \entity -> do + send $ Msg (EmitterEntityMsg entity) + + receiveWorker :: (SyncError -> STM ()) -> WebApp () + receiveWorker onErrSTM = forever $ do + atomically $ do + receive >>= \case + Err err -> onErrSTM err + Msg (ReceiverInitStream {}) -> onErrSTM (InitializationError "Received duplicate ReceiverInitStream message") + Msg (ReceiverEntityRequest (EntityRequestMsg {hashes})) -> do + Debug.debugM Debug.Temp "Got new entity requests" hashes + modifyTVar' requestedEntitiesVar (\s -> Set.union s (Set.fromList hashes)) + +fetchEntities :: CodebaseEnv -> Set (EntityKind, Hash32) -> WebApp (Vector (Entity Hash32 Text)) +fetchEntities codebase reqs = do + PG.runTransaction $ Q.fetchSerialisedEntities codebase reqs + +-- entityDependencies :: Entity hash text -> Set (EntityKind, Hash32) +-- entityDependencies (Entity {entityData}) = do diff --git a/src/Share/Web/UCM/SyncV3/Queries.hs b/src/Share/Web/UCM/SyncV3/Queries.hs new file mode 100644 index 00000000..136a07c1 --- /dev/null +++ b/src/Share/Web/UCM/SyncV3/Queries.hs @@ -0,0 +1,60 @@ +module Share.Web.UCM.SyncV3.Queries + ( fetchSerialisedEntities, + ) +where + +import Data.Vector (Vector) +import Share.Codebase.Types (CodebaseEnv (..)) +import Share.Postgres +import Share.Prelude +import Unison.SyncV3.Types +import U.Codebase.Sqlite.TempEntity (TempEntity) +import Unison.Hash32 (Hash32) +import Unison.SyncV2.Types (CBORBytes) + +fetchSerialisedEntities :: (QueryM m) => CodebaseEnv -> Set (EntityKind, Hash32) -> m (Vector (Entity Hash32 Text)) +fetchSerialisedEntities (CodebaseEnv {codebaseOwner}) requestedEntities = + do + queryVectorRows @(EntityKind, CBORBytes TempEntity, Hash32, EntityDepth) + [sql| + WITH requested(kind, hash) AS ( + SELECT kind, hash FROM ^{toTable $ toList requestedEntities} AS t(kind, hash) + ) + (SELECT req.kind, bytes.bytes, ch.base32, cd.depth + FROM requested req + JOIN component_hashes ch ON req.hash = ch.base32 + JOIN serialized_components sc ON sc.user_id = #{codebaseOwner} AND ch.id = sc.component_hash_id + JOIN bytes ON sc.bytes_id = bytes.id + JOIN component_depth cd ON ch.id = cd.component_hash_id + WHERE req.kind = 'component' + ) + UNION ALL + (SELECT req.kind, bytes.bytes, req.hash, pd.depth + FROM requested req + JOIN patches p ON req.hash = p.hash + JOIN serialized_patches sp ON p.id = sp.patch_id + JOIN bytes ON sp.bytes_id = bytes.id + JOIN patch_depth pd ON p.id = pd.patch_id + WHERE req.kind = 'patch' + ) + UNION ALL + (SELECT req.kind, bytes.bytes, req.hash, nd.depth + FROM requested req + JOIN branch_hashes bh ON req.hash = bh.base32 + JOIN serialized_namespaces sn ON bh.id = sn.namespace_hash_id + JOIN bytes ON sn.bytes_id = bytes.id + JOIN namespace_depth nd ON bh.id = nd.namespace_hash_id + WHERE req.kind = 'namespace' + ) + UNION ALL + -- TODO: Should probably join in a batch of causal spines here too + -- to improve parallelism and avoid long-spine bottlenecks. + (SELECT req.kind, bytes.bytes, req.hash, cd.depth + FROM requested req + JOIN causals c ON req.hash = c.hash + JOIN serialized_causals sc ON c.id = sc.causal_id + JOIN bytes ON sc.bytes_id = bytes.id + JOIN causal_depth cd ON c.id = cd.causal_id + ) + |] + <&> fmap (\(entityKind, entityData, entityHash, entityDepth) -> Entity {entityKind, entityData, entityHash, entityDepth}) diff --git a/unison b/unison index d85de688..31655c33 160000 --- a/unison +++ b/unison @@ -1 +1 @@ -Subproject commit d85de68861c65d6419a6ac9df8022400adb27f4d +Subproject commit 31655c33f2bff2c6b23fc366b22bb3bd0ad24a9a