Skip to content

Commit

Permalink
introduce MonadSession
Browse files Browse the repository at this point in the history
  • Loading branch information
ners committed Jul 6, 2024
1 parent fad7bc2 commit 02595d6
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 93 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

## 0.3.1.0 -- 2024-06-24

* Expose SessionT, lift all functions to MonadIO
* Introduce SessionT and MonadSession
* Lift all functions to MonadSession
* Add Session.getAllVersionedDocs

## 0.3.0.0 -- 2024-04-04
Expand Down
6 changes: 3 additions & 3 deletions flake.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions lsp-client.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ common common
bytestring >= 0.9 && < 0.13,
lens >= 5 && < 6,
lsp-types >= 2 && < 3,
unliftio >= 0.2 && < 0.3,
row-types,
unliftio >= 0.2 && < 0.3,

library
import: common
Expand All @@ -73,8 +73,9 @@ library
stm,
text,
text-rope,
unordered-containers,
transformers,
unix-compat >= 0.7.1 && < 0.8,
unordered-containers,
other-modules:
Control.Concurrent.STM.TVar.Extra,
exposed-modules:
Expand Down
187 changes: 100 additions & 87 deletions src/Language/LSP/Client/Session.hs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ import Control.Lens hiding (Empty, List)
import Control.Lens.Extras (is)
import Control.Monad (unless, when)
import Control.Monad.IO.Class (MonadIO (liftIO))
import Control.Monad.Reader (ReaderT, asks)
import Control.Monad.Reader (ReaderT (runReaderT), ask, asks)
import Control.Monad.State (StateT, execState)
import Control.Monad.Trans.Class (MonadTrans, lift)
import Data.Default (def)
import Data.Foldable (foldl', foldr', forM_, toList)
import Data.Function (on)
Expand Down Expand Up @@ -95,6 +96,15 @@ type SessionT = ReaderT SessionState

type Session = SessionT IO

class (Monad m) => MonadSession m where
liftSession :: forall a. Session a -> m a

instance {-# OVERLAPPING #-} (MonadIO m) => MonadSession (SessionT m) where
liftSession a = liftIO . runReaderT a =<< ask

instance {-# OVERLAPPABLE #-} (MonadTrans t, MonadSession m) => MonadSession (t m) where
liftSession = lift . liftSession

documentChangeUri :: DocumentChange -> Uri
documentChangeUri (InL x) = x ^. textDocument . uri
documentChangeUri (InR (InL x)) = x ^. uri
Expand All @@ -105,22 +115,22 @@ documentChangeUri (InR (InR (InR x))) = x ^. uri
Note that this does not provide any business logic beyond updating the session state; you most likely
want to use `sendRequest` and `receiveNotification` to register callbacks for specific messages.
-}
handleServerMessage :: forall io. (MonadIO io) => FromServerMessage -> SessionT io ()
handleServerMessage :: (MonadSession m) => FromServerMessage -> m ()
handleServerMessage (FromServerMess SMethod_Progress req) =
when (anyOf folded ($ req ^. params . value) [is _workDoneProgressBegin, is _workDoneProgressEnd]) $
liftSession . when (anyOf folded ($ req ^. params . value) [is _workDoneProgressBegin, is _workDoneProgressEnd]) $
asks progressTokens
>>= liftIO
. flip modifyTVarIO (HashSet.insert $ req ^. params . token)
handleServerMessage (FromServerMess SMethod_ClientRegisterCapability req) =
asks serverCapabilities >>= liftIO . flip modifyTVarIO (HashMap.union (HashMap.fromList newRegs))
liftSession $ asks serverCapabilities >>= liftIO . flip modifyTVarIO (HashMap.union (HashMap.fromList newRegs))
where
regs = req ^.. params . registrations . traversed . to toSomeRegistration . _Just
newRegs = (\sr@(SomeRegistration r) -> (r ^. id, sr)) <$> regs
handleServerMessage (FromServerMess SMethod_ClientUnregisterCapability req) =
asks serverCapabilities >>= liftIO . flip modifyTVarIO (flip (foldr' HashMap.delete) unRegs)
liftSession $ asks serverCapabilities >>= liftIO . flip modifyTVarIO (flip (foldr' HashMap.delete) unRegs)
where
unRegs = (^. id) <$> req ^. params . unregisterations
handleServerMessage (FromServerMess SMethod_WorkspaceApplyEdit r) = do
handleServerMessage (FromServerMess SMethod_WorkspaceApplyEdit r) = liftSession $ do
-- First, prefer the versioned documentChanges field
allChangeParams <- case r ^. params . edit . documentChanges of
Just cs -> do
Expand Down Expand Up @@ -167,7 +177,7 @@ handleServerMessage (FromServerMess SMethod_WorkspaceApplyEdit r) = do
logger :: LogAction (StateT VFS Identity) (WithSeverity VfsLog)
logger = LogAction $ \WithSeverity{..} -> case getSeverity of Error -> error $ show getMsg; _ -> pure ()

checkIfNeedsOpened :: Uri -> SessionT io ()
checkIfNeedsOpened :: Uri -> Session ()
checkIfNeedsOpened uri = do
isOpen <- asks vfs >>= liftIO . readTVarIO <&> has (vfsMap . ix (toNormalizedUri uri))

Expand Down Expand Up @@ -200,19 +210,19 @@ handleServerMessage (FromServerMess SMethod_WorkspaceApplyEdit r) = do
getParamsFromDocumentChange (InL textDocumentEdit) = getParamsFromTextDocumentEdit textDocumentEdit
getParamsFromDocumentChange _ = Nothing

bumpNewestVersion :: OptionalVersionedTextDocumentIdentifier -> SessionT io OptionalVersionedTextDocumentIdentifier
bumpNewestVersion :: OptionalVersionedTextDocumentIdentifier -> Session OptionalVersionedTextDocumentIdentifier
bumpNewestVersion OptionalVersionedTextDocumentIdentifier{_uri, _version = InL _} = do
VersionedTextDocumentIdentifier{_version} <- head <$> textDocumentVersions _uri
pure OptionalVersionedTextDocumentIdentifier{_version = InL _version, ..}
bumpNewestVersion i = pure i

-- For a uri returns an infinite list of versions [n+1,n+2,...]
-- where n is the current version
textDocumentVersions :: Uri -> SessionT io [VersionedTextDocumentIdentifier]
textDocumentVersions :: Uri -> Session [VersionedTextDocumentIdentifier]
textDocumentVersions _uri = do
tail . iterate (version +~ 1) <$> getVersionedDoc TextDocumentIdentifier{_uri}

textDocumentEdits :: Uri -> [TextEdit] -> SessionT io [TextDocumentEdit]
textDocumentEdits :: Uri -> [TextEdit] -> Session [TextDocumentEdit]
textDocumentEdits uri edits = do
versions <- textDocumentVersions uri
pure $
Expand All @@ -236,20 +246,20 @@ handleServerMessage (FromServerMess SMethod_WorkspaceApplyEdit r) = do
{ _contentChanges = concat . toList $ toList . (^. contentChanges) <$> params
, _textDocument = head params ^. textDocument
}
handleServerMessage (FromServerMess SMethod_WindowWorkDoneProgressCreate req) = sendResponse req $ Right Null
handleServerMessage (FromServerMess SMethod_WindowWorkDoneProgressCreate req) = liftSession . sendResponse req $ Right Null
handleServerMessage _ = pure ()

{- | Sends a request to the server, with a callback that fires when the response arrives.
Multiple requests can be waiting at the same time.
-}
sendRequest
:: forall (m :: Method 'ClientToServer 'Request) io
. (TMessage m ~ TRequestMessage m, MonadIO io)
=> SMethod m
-> MessageParams m
-> (TResponseMessage m -> IO ())
-> SessionT io (LspId m)
sendRequest requestMethod _params requestCallback = do
:: forall (method :: Method 'ClientToServer 'Request) m
. (TMessage method ~ TRequestMessage method, MonadSession m)
=> SMethod method
-> MessageParams method
-> (TResponseMessage method -> IO ())
-> m (LspId method)
sendRequest requestMethod _params requestCallback = liftSession $ do
_id <- asks lastRequestId >>= liftIO . overTVarIO (+ 1) <&> IdInt
asks pendingRequests >>= liftIO . flip modifyTVarIO (updateRequestMap _id RequestCallback{..})
sendMessage $ fromClientReq TRequestMessage{_jsonrpc = "2.0", _method = requestMethod, ..}
Expand All @@ -259,22 +269,22 @@ sendRequest requestMethod _params requestCallback = do
Users of this library cannot register callbacks to server requests, so this function is probably of no use to them.
-}
sendResponse
:: forall (m :: Method 'ServerToClient 'Request) io
. (MonadIO io)
=> TRequestMessage m
-> Either ResponseError (MessageResult m)
-> SessionT io ()
:: forall (method :: Method 'ServerToClient 'Request) m
. (MonadSession m)
=> TRequestMessage method
-> Either ResponseError (MessageResult method)
-> m ()
sendResponse TRequestMessage{..} _result =
sendMessage $ FromClientRsp _method TResponseMessage{_id = Just _id, ..}
liftSession . sendMessage $ FromClientRsp _method TResponseMessage{_id = Just _id, ..}

-- | Sends a request to the server and synchronously waits for its response.
request
:: forall (m :: Method 'ClientToServer 'Request) io
. (TMessage m ~ TRequestMessage m, MonadIO io)
=> SMethod m
-> MessageParams m
-> SessionT io (TResponseMessage m)
request method params = do
:: forall (method :: Method 'ClientToServer 'Request) m
. (TMessage method ~ TRequestMessage method, MonadSession m)
=> SMethod method
-> MessageParams method
-> m (TResponseMessage method)
request method params = liftSession $ do
done <- liftIO newEmptyMVar
void $ sendRequest method params $ putMVar done
liftIO $ takeMVar done
Expand All @@ -290,12 +300,12 @@ getResponseResult response = either err Prelude.id $ response ^. result

-- | Sends a notification to the server. Updates the VFS if the notification is a document update.
sendNotification
:: forall (m :: Method 'ClientToServer 'Notification) io
. (TMessage m ~ TNotificationMessage m, MonadIO io)
=> SMethod m
-> MessageParams m
-> SessionT io ()
sendNotification m params = do
:: forall (method :: Method 'ClientToServer 'Notification) m
. (TMessage method ~ TNotificationMessage method, MonadSession m)
=> SMethod method
-> MessageParams method
-> m ()
sendNotification m params = liftSession $ do
let n = TNotificationMessage "2.0" m params
vfs <- asks vfs
case m of
Expand All @@ -309,47 +319,49 @@ sendNotification m params = do
If multiple callbacks are registered for the same notification method, they will all be called.
-}
receiveNotification
:: forall (m :: Method 'ServerToClient 'Notification) io
. (TMessage m ~ TNotificationMessage m, MonadIO io)
=> SMethod m
-> (TMessage m -> IO ())
-> SessionT io ()
:: forall (method :: Method 'ServerToClient 'Notification) m
. (TMessage method ~ TNotificationMessage method, MonadSession m)
=> SMethod method
-> (TMessage method -> IO ())
-> m ()
receiveNotification method notificationCallback =
asks notificationHandlers
>>= liftIO
. flip
modifyTVarIO
( appendNotificationCallback method NotificationCallback{..}
)
liftSession $
asks notificationHandlers
>>= liftIO
. flip
modifyTVarIO
( appendNotificationCallback method NotificationCallback{..}
)

{- | Clears the registered callback for the given notification method, if any.
If multiple callbacks have been registered, this clears /all/ of them.
-}
clearNotificationCallback
:: forall (m :: Method 'ServerToClient 'Notification) io
. (MonadIO io)
=> SMethod m
-> SessionT io ()
:: forall (method :: Method 'ServerToClient 'Notification) m
. (MonadSession m)
=> SMethod method
-> m ()
clearNotificationCallback method =
asks notificationHandlers
>>= liftIO
. flip
modifyTVarIO
( removeNotificationCallback method
)
liftSession $
asks notificationHandlers
>>= liftIO
. flip
modifyTVarIO
( removeNotificationCallback method
)

-- | Queues a message to be sent to the server at the client's earliest convenience.
sendMessage :: (MonadIO io) => FromClientMessage -> SessionT io ()
sendMessage msg = asks outgoing >>= liftIO . atomically . (`writeTQueue` msg)
sendMessage :: (MonadSession m) => FromClientMessage -> m ()
sendMessage msg = liftSession $ asks outgoing >>= liftIO . atomically . (`writeTQueue` msg)

lspClientInfo :: Rec ("name" .== Text .+ "version" .== Maybe Text)
lspClientInfo = #name .== "lsp-client" .+ #version .== Just CURRENT_PACKAGE_VERSION

{- | Performs the initialisation handshake and synchronously waits for its completion.
When the function completes, the session is initialised.
-}
initialize :: (MonadIO io) => SessionT io ()
initialize = do
initialize :: (MonadSession m) => m ()
initialize = liftSession $ do
pid <- liftIO getProcessID
response <-
request
Expand Down Expand Up @@ -378,16 +390,16 @@ initialize = do
the server that one does exist.
-}
createDoc
:: (MonadIO io)
:: (MonadSession m)
=> FilePath
-- ^ The path to the document to open, __relative to the root directory__.
-> Text
-- ^ The text document's language identifier, e.g. @"haskell"@.
-> Text
-- ^ The content of the text document to create.
-> SessionT io TextDocumentIdentifier
-> m TextDocumentIdentifier
-- ^ The identifier of the document just created.
createDoc file language contents = do
createDoc file language contents = liftSession $ do
serverCaps <- asks serverCapabilities >>= liftIO . readTVarIO
clientCaps <- asks clientCapabilities
rootDir <- asks rootDir
Expand Down Expand Up @@ -435,17 +447,17 @@ createDoc file language contents = do
{- | Opens a text document that /exists on disk/, and sends a
@textDocument/didOpen@ notification to the server.
-}
openDoc :: (MonadIO io) => FilePath -> Text -> SessionT io TextDocumentIdentifier
openDoc file language = do
openDoc :: (MonadSession m) => FilePath -> Text -> m TextDocumentIdentifier
openDoc file language = liftSession $ do
rootDir <- asks rootDir
contents <- liftIO . Text.readFile $ rootDir </> file
openDoc' file language contents

{- | This is a variant of `openDoc` that takes the file content as an argument.
Use this is the file exists /outside/ of the current workspace.
-}
openDoc' :: (MonadIO io) => FilePath -> Text -> Text -> SessionT io TextDocumentIdentifier
openDoc' file language contents = do
openDoc' :: (MonadSession m) => FilePath -> Text -> Text -> m TextDocumentIdentifier
openDoc' file language contents = liftSession $ do
rootDir <- asks rootDir
let _uri = filePathToUri $ rootDir </> file
sendNotification
Expand All @@ -462,45 +474,46 @@ openDoc' file language contents = do
pure TextDocumentIdentifier{..}

-- | Closes a text document and sends a @textDocument/didClose@ notification to the server.
closeDoc :: (MonadIO io) => TextDocumentIdentifier -> SessionT io ()
closeDoc :: (MonadSession m) => TextDocumentIdentifier -> m ()
closeDoc docId =
sendNotification
SMethod_TextDocumentDidClose
DidCloseTextDocumentParams
{ _textDocument =
TextDocumentIdentifier
{ _uri = docId ^. uri
}
}
liftSession $
sendNotification
SMethod_TextDocumentDidClose
DidCloseTextDocumentParams
{ _textDocument =
TextDocumentIdentifier
{ _uri = docId ^. uri
}
}

-- | Changes a text document and sends a @textDocument/didChange@ notification to the server.
changeDoc :: (MonadIO io) => TextDocumentIdentifier -> [TextDocumentContentChangeEvent] -> SessionT io ()
changeDoc docId _contentChanges = do
changeDoc :: (MonadSession m) => TextDocumentIdentifier -> [TextDocumentContentChangeEvent] -> m ()
changeDoc docId _contentChanges = liftSession $ do
_textDocument <- getVersionedDoc docId <&> version +~ 1
sendNotification SMethod_TextDocumentDidChange DidChangeTextDocumentParams{..}

-- | Gets the Uri for the file relative to the session's root directory.
getDocUri :: (MonadIO io) => FilePath -> SessionT io Uri
getDocUri file = do
getDocUri :: (MonadSession m) => FilePath -> m Uri
getDocUri file = liftSession $ do
rootDir <- asks rootDir
pure . filePathToUri $ rootDir </> file

-- | The current text contents of a document.
documentContents :: (MonadIO io) => TextDocumentIdentifier -> SessionT io (Maybe Rope)
documentContents TextDocumentIdentifier{_uri} = do
documentContents :: (MonadSession m) => TextDocumentIdentifier -> m (Maybe Rope)
documentContents TextDocumentIdentifier{_uri} = liftSession $ do
vfs <- asks vfs >>= liftIO . readTVarIO
pure $ vfs ^? vfsMap . ix (toNormalizedUri _uri) . to _file_text

-- | Adds the current version to the document, as tracked by the session.
getVersionedDoc :: (MonadIO io) => TextDocumentIdentifier -> SessionT io VersionedTextDocumentIdentifier
getVersionedDoc TextDocumentIdentifier{_uri} = do
getVersionedDoc :: (MonadSession m) => TextDocumentIdentifier -> m VersionedTextDocumentIdentifier
getVersionedDoc TextDocumentIdentifier{_uri} = liftSession $ do
vfs <- asks vfs >>= liftIO . readTVarIO
let _version = fromMaybe 0 $ vfs ^? vfsMap . ix (toNormalizedUri _uri) . to virtualFileVersion
pure VersionedTextDocumentIdentifier{..}

-- | Get all the versioned documents tracked by the session.
getAllVersionedDocs :: (MonadIO io) => SessionT io [VersionedTextDocumentIdentifier]
getAllVersionedDocs = do
getAllVersionedDocs :: (MonadSession m) => m [VersionedTextDocumentIdentifier]
getAllVersionedDocs = liftSession $ do
vfs <- asks vfs >>= liftIO . readTVarIO
pure $
Map.toList (vfs ^. vfsMap) <&> \(nuri, vf) ->
Expand Down

0 comments on commit 02595d6

Please sign in to comment.