Support LSP's textDocument/definition #86

Open
krobelus opened this issue Oct 15, 2021 · 1 comment

Comments

@krobelus

Hi, I'm glad I found a language server for SMT-LIB.
Currently, the most important feature for me is go-to-definition
("textDocument/definition"), so I can quickly navigate my formulas.
This includes mainly names declared by declare-* or define-* statements,
and maybe also let bindings.
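
To make the request concrete, here is a rough sketch of what a "textDocument/definition" exchange could look like for the small example used in the script below, written as Python dicts. The URI is made up for illustration. The field names follow the LSP specification, and positions are zero-based.

```python
import json

# Hypothetical document URI, for illustration only.
URI = "file:///tmp/example.smt2"

# Request: the cursor is on the "x" in "(assert (= x 123))".
# LSP positions are zero-based, so the third line of the file is line 2.
request = {
    "jsonrpc": "2.0",
    "id": 1,
    "method": "textDocument/definition",
    "params": {
        "textDocument": {"uri": URI},
        "position": {"line": 2, "character": 11},
    },
}

# Response: a Location pointing at the "x" in "(declare-const x Int)".
response = {
    "jsonrpc": "2.0",
    "id": 1,
    "result": {
        "uri": URI,
        "range": {
            "start": {"line": 0, "character": 15},
            "end": {"line": 0, "character": 16},
        },
    },
}

print(json.dumps(request, indent=2))
print(json.dumps(response, indent=2))
```

A server may also answer with `null` when no definition is found, which is what the script below does.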

I've hacked together my own language server in Python that works pretty well
for me. But I'm always happy to get rid of that and move to a real solution.
I'm not sure when I'll get to delve into OCaml, but for now I'll leave my
Python implementation here, in case it helps anyone. There's some boilerplate -
the actual routine to find definitions is really simple (and inefficient).
Let me know if I can help with examples/explanation.

(The next logical step would be to implement "textDocument/references", to list all uses of a name)

#!/usr/bin/env python3

# Dependencies:
# pip install python-language-server

import sys
import logging
import threading

from pyls import _utils, uris
from pyls_jsonrpc.dispatchers import MethodDispatcher
from pyls_jsonrpc.endpoint import Endpoint
from pyls_jsonrpc.streams import JsonRpcStreamReader, JsonRpcStreamWriter

from pyls.workspace import Workspace

"""
Toy language server that implements textDocument/definition

For example, given this file

```smt2
(declare-const x Int)

(assert (= x 123))
```

if the cursor is on the "x" in line 3, textDocument/definition will return
the position of the x in line 1.
"""

# SMTLIBLanguageServer is adapted from pyls
log = logging.getLogger(__name__)
PARENT_PROCESS_WATCH_INTERVAL = 10  # 10 s
MAX_WORKERS = 64


class SMTLIBLanguageServer(MethodDispatcher):
    """ Implementation of the Microsoft VSCode Language Server Protocol
    https://github.com/Microsoft/language-server-protocol/blob/master/versions/protocol-1-x.md
    """

    def __init__(self, rx, tx, check_parent_process=False):
        self.workspace = None
        self.root_uri = None
        self.watching_thread = None
        self.workspaces = {}
        self.uri_workspace_mapper = {}

        self._jsonrpc_stream_reader = JsonRpcStreamReader(rx)
        self._jsonrpc_stream_writer = JsonRpcStreamWriter(tx)
        self._check_parent_process = check_parent_process
        self._endpoint = Endpoint(
            self, self._jsonrpc_stream_writer.write, max_workers=MAX_WORKERS)
        self._dispatchers = []
        self._shutdown = False

    def start(self):
        """Entry point for the server."""
        self._jsonrpc_stream_reader.listen(self._endpoint.consume)

    def __getitem__(self, item):
        """Override getitem to fallback through multiple dispatchers."""
        if self._shutdown and item != 'exit':
            # exit is the only allowed method during shutdown
            log.debug("Ignoring non-exit method during shutdown: %s", item)
            raise KeyError
        try:
            return super(SMTLIBLanguageServer, self).__getitem__(item)
        except KeyError:
            # Fallback through extra dispatchers
            for dispatcher in self._dispatchers:
                try:
                    return dispatcher[item]
                except KeyError:
                    continue

        raise KeyError()

    def m_shutdown(self, **_kwargs):
        self._shutdown = True
        return None

    def m_exit(self, **_kwargs):
        self._endpoint.shutdown()
        self._jsonrpc_stream_reader.close()
        self._jsonrpc_stream_writer.close()

    def _match_uri_to_workspace(self, uri):
        workspace_uri = _utils.match_uri_to_workspace(uri, self.workspaces)
        return self.workspaces.get(workspace_uri, self.workspace)

    def capabilities(self):
        server_capabilities = {
            "definitionProvider": True,
        }
        log.info('Server capabilities: %s', server_capabilities)
        return server_capabilities

    def m_initialize(self, processId=None, rootUri=None, rootPath=None, initializationOptions=None, **_kwargs):
        log.debug('Language server initialized with %s %s %s %s',
                  processId, rootUri, rootPath, initializationOptions)
        if rootUri is None:
            rootUri = uris.from_fs_path(rootPath) if rootPath is not None else ''

        self.workspaces.pop(self.root_uri, None)
        self.root_uri = rootUri
        self.workspace = Workspace(rootUri, self._endpoint, None)
        self.workspaces[rootUri] = self.workspace

        if self._check_parent_process and processId is not None and self.watching_thread is None:
            def watch_parent_process(pid):
                # exit when the given pid is not alive
                if not _utils.is_process_alive(pid):
                    log.info("parent process %s is not alive, exiting!", pid)
                    self.m_exit()
                else:
                    threading.Timer(PARENT_PROCESS_WATCH_INTERVAL,
                                    watch_parent_process, args=[pid]).start()

            self.watching_thread = threading.Thread(
                target=watch_parent_process, args=(processId,))
            self.watching_thread.daemon = True
            self.watching_thread.start()
        return {'capabilities': self.capabilities()}

    def m_initialized(self, **_kwargs):
        pass

    def m_text_document__definition(self, textDocument=None, position=None, **_kwargs):
        doc_uri = textDocument["uri"]
        workspace = self._match_uri_to_workspace(doc_uri)
        doc = workspace.get_document(doc_uri) if doc_uri else None
        return smt_definition(doc, position)

    # The remaining notifications/requests are accepted but ignored.
    def m_text_document__did_close(self, textDocument=None, **_kwargs):
        pass

    def m_text_document__did_open(self, textDocument=None, **_kwargs):
        pass

    def m_text_document__did_change(self, contentChanges=None, textDocument=None, **_kwargs):
        pass

    def m_text_document__did_save(self, textDocument=None, **_kwargs):
        pass

    def m_text_document__completion(self, textDocument=None, **_kwargs):
        pass


def flatten(list_of_lists):
    return [item for lst in list_of_lists for item in lst]

def merge(list_of_dicts):
    return {k: v for dictionary in list_of_dicts for k, v in dictionary.items()}

def smt_definition(document, position):
    pos = definition(document.source, position["line"], position["character"])
    if pos is None:
        return None

    line, col, token = pos

    offset = 1 if len(token) == 1 else (len(token) + 1)
    if col == 0:
        line -= 1
        col = len(document.lines[line]) - offset
    else:
        col = col - offset

    return {
            'uri': document.uri,
            'range': {
                'start': {'line': line, 'character': col},
                'end': {'line': line, 'character': col},
            }
        }

def definition(source, cursor_line, cursor_character):
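    nodes = list(parser().parse_smtlib(source))
    node_at_cursor = find_leaf_node_at(cursor_line, cursor_character, nodes)
    if node_at_cursor is None:
        # The cursor is not on a token, so there is nothing to look up.
        return None
    line, col, node = find_definition_for(node_at_cursor, nodes)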
    if node is None:
        return None
    return line, col, node_at_cursor

def find_leaf_node_at(line, col, nodes):
    prev_line_end = -1
    prev_col_end = -1
    needle = (line, col)
    for line_end, col_end, node in nodes:
        prev_range = (prev_line_end-1, prev_col_end)
        cur_range = (line_end, col_end)
        if prev_range < needle < cur_range:
            if isinstance(node, str):
                return node
            else:
                node_at = find_leaf_node_at(line, col, node)
                assert node_at is not None
                return node_at
        prev_line_end = line_end
        prev_col_end = col_end
    return None

def stripprefix(x, prefix):
    if x.startswith(prefix):
        return x[len(prefix):]
    return x

def find_definition_for(needle, nodes):
    for node in nodes:
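        line_end, col_end, n = node
        # Skip stray top-level tokens and forms too short to declare a name.
        if not isinstance(n, list) or len(n) < 2:
            continue
        _, _, head = n[0]
        if not isinstance(head, str):
            continue
        if not head.startswith("declare-") and not head.startswith("define-"):
            continue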
        _, _, symbol = n[1]

        if head in ("declare-const", "define-const", "declare-fun", "define-fun", "define-fun-rec", "declare-datatype"):
            if symbol == needle:
                return n[1]
            continue

        if head in ("declare-datatypes", "define-funs-rec"):
            for i, tmp in enumerate(symbol):
                _, _, type_parameter_declaration = tmp
                _, _, type_name = type_parameter_declaration[0]
                if type_name == needle:
                    return type_parameter_declaration[0]
            if head == "declare-datatypes":
                constructor = dfs(needle, node)
                if constructor is not None:
                    return constructor
                constructor = dfs(stripprefix(needle, "is-"), node)
                if constructor is not None:
                    return constructor
            continue
        # Asserting on a non-empty string never fails, so log instead.
        log.debug("unsupported form: %s", head)
    return -1, -1, None

def dfs(needle, node):
    assert isinstance(node, tuple)
    _, _, n = node
    if isinstance(n, str):
        if n == needle:
            return node
        else:
            return None
    for child in n:
        found = dfs(needle, child)
        if found is not None:
            return found
    return None

class parser:
    def __init__(self):
        self.pos = 0
        self.line = 0
        self.col = -1
        self.text = None

    def nextch(self):
        char = self.text[self.pos]
        self.pos += 1
        self.col += 1
        if char == "\n":
            self.line += 1
            self.col = 0
        return char

    def parse_smtlib(self, text):
        assert self.text is None
        self.text = text
        return self.parse_smtlib_aux()

    def parse_smtlib_aux(self):
        exprs = []
        cur_expr = None
        size = len(self.text)

        while self.pos < size:
            char = self.nextch()

            # Stolen from ddSMT's parser. Not fully SMT-LIB compliant but good enough.
            # String literals/quoted symbols
            if char in ('"', '|'):
                first_char = char
                literal = [char]
                # Read until terminating " or |
                while True:
                    if self.pos >= size:
                        return
                    char = self.nextch()
                    literal.append(char)
                    if char == first_char:
                        # Check if the quote is escaped: "a "" b" is one string literal
                        if char == '"' and self.pos < size and self.text[self.pos] == '"':
                            literal.append(self.text[self.pos])
                            self.nextch()
                            continue
                        break
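                # Join into a str so literals look like any other token.
                literal = ''.join(literal)
                if cur_expr is not None:
                    cur_expr.append((self.line, self.col, literal))
                else:
                    yield self.line, self.col, literal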
                continue

            # Comments
            if char == ';':
                # Read until newline
                while self.pos < size:
                    char = self.nextch()
                    if char == '\n':
                        break
                continue

            # Open s-expression
            if char == '(':
                cur_expr = []
                exprs.append(cur_expr)
                continue

            # Close s-expression
            if char == ')':
                cur_expr = exprs.pop()
                # Do we have nested s-expressions?
                if exprs:
                    exprs[-1].append((self.line, self.col, cur_expr))
                    cur_expr = exprs[-1]
                else:
                    yield self.line, self.col, cur_expr
                    cur_expr = None
                continue

            # Identifier
            if char not in (' ', '\t', '\n'):
                token = [char]
                while True:
                    if self.pos >= size:
                        return
                    char = self.text[self.pos]
                    if char in ('(', ')', ';'):
                        break
                    self.nextch()
                    if char in (' ', '\t', '\n'):
                        break
                    token.append(char)

                token = ''.join(token)

                # Append to current s-expression
                if cur_expr is not None:
                    cur_expr.append((self.line, self.col, token))
                else:
                    yield self.line, self.col, token

def serve():
    stdin = sys.stdin.buffer
    stdout = sys.stdout.buffer
    server = SMTLIBLanguageServer(stdin, stdout)
    server.start()

if __name__ == "__main__":
    if len(sys.argv) >= 2 and sys.argv[1] == "definition":
        line = int(sys.argv[2])
        col = int(sys.argv[3])
        print(definition(sys.stdin.read(), line, col))
    else:
        serve()
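
For the "next logical step" mentioned above ("textDocument/references"), a minimal sketch could walk the same tree shape the parser produces, i.e. `(line_end, col_end, node)` tuples where `node` is either a token string or a list of such tuples. The function name `find_references` and the hand-built tree are hypothetical, and this is not scope-aware: it simply matches every token with the same name, including the declaration itself.

```python
def find_references(needle, nodes):
    """Collect the end positions of every token equal to `needle`.

    `nodes` is an iterable of (line_end, col_end, node) tuples, where node
    is either a token (str) or a list of such tuples.
    """
    found = []
    for line_end, col_end, node in nodes:
        if isinstance(node, str):
            if node == needle:
                found.append((line_end, col_end))
        else:
            # Nested s-expression: search its children recursively.
            found.extend(find_references(needle, node))
    return found


# Hand-built tree for:
#   (declare-const x Int)
#
#   (assert (= x 123))
# Positions are illustrative, not produced by the real parser.
tree = [
    (0, 20, [(0, 14, "declare-const"), (0, 16, "x"), (0, 20, "Int")]),
    (2, 17, [
        (2, 7, "assert"),
        (2, 16, [(2, 10, "="), (2, 12, "x"), (2, 16, "123")]),
    ]),
]

print(find_references("x", tree))  # [(0, 16), (2, 12)]
```

Handling let bindings properly would need shadowing rules on top of this, which the sketch leaves out.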
@Gbury
Owner

Gbury commented Oct 24, 2021

Hi,
That's very interesting to know! It helps me decide which features to prioritize, so don't hesitate to report what you need or want Dolmen (and the LSP server) to do. ^^

After thinking about it this week, I now have a fairly good idea of how I want to implement go-to-definition in Dolmen, and I'll get on that as soon as I have the time.
