为PyQt5的QCustomPlot建立接口调用存根(Stub)

适用于PyQt5的QCustomPlot绘图控件默认并不包含接口调用存根,导致在VSCode等工具中编码时,无法自动生成代码提示。此时,可以手动建立调用存根。

开始前,请确保PyQt5、QCustomPlot-PyQt5等组件已安装完成,同时请安装mypy和astor包:

pip install mypy
pip install astor

随后,运行下面的Python脚本:

import mypy
import mypy.stubgen
 
from PyQt5.QtCore import Qt
from PyQt5.QtGui import QPen, QBrush, QColor
from QCustomPlot_PyQt5 import *
import sys
import ast
import astor
 
import os
 
class RemoveUnderscoreVars(ast.NodeTransformer):
    """Strip underscore-prefixed variable definitions from class bodies.

    Only plain assignments (``_x = ...``) and annotated assignments
    (``_x: T``) whose target is a simple name are removed. Methods,
    nested classes and module-level variables are left alone. Nested
    classes are processed recursively.
    """

    @staticmethod
    def _is_underscore_variable(stmt):
        """Return True if *stmt* assigns to a name starting with ``_``."""
        if isinstance(stmt, ast.AnnAssign):
            targets = [stmt.target]
        elif isinstance(stmt, ast.Assign):
            targets = stmt.targets
        else:
            return False
        return any(
            isinstance(target, ast.Name) and target.id.startswith('_')
            for target in targets
        )

    def visit_ClassDef(self, node):
        """Filter the class body, then recurse into what is left.

        Returns the same ``ClassDef`` node, mutated in place.
        """
        kept = [
            stmt for stmt in node.body
            if not self._is_underscore_variable(stmt)
        ]
        # A class needs at least one statement; without this a class that
        # only held underscore variables would be emitted with an empty
        # body, which is not valid Python source.
        node.body = kept or [ast.Pass()]

        # Recurse so that nested classes are filtered as well.
        self.generic_visit(node)
        return node
 
class RemoveFailedImports(ast.NodeTransformer):
    """Remove imports that cannot be resolved in the running interpreter.

    Every absolute import found in the tree is actually attempted with
    ``__import__``, so visiting a tree imports (and therefore executes)
    the referenced modules.
    """

    def visit_Import(self, node):
        """Keep only the aliases of ``import a, b`` that can be imported.

        Returns the (mutated) node, or ``None`` to delete the statement
        when none of its modules is importable.
        """
        importable = []
        for alias in node.names:
            try:
                __import__(alias.name)
            except ImportError:
                continue
            importable.append(alias)

        if not importable:
            return None
        node.names = importable
        return node

    def visit_ImportFrom(self, node):
        """Keep only the names of ``from m import ...`` that exist in ``m``.

        Returns the (mutated) node, or ``None`` to delete the statement
        when the module cannot be imported or none of the names exists.
        """
        # Relative imports (``from . import x`` / ``from .m import x``)
        # cannot be checked without knowing the enclosing package, and
        # ``node.module`` may be None for them; leave them untouched
        # instead of crashing or probing an unrelated absolute module.
        if node.level or node.module is None:
            return node

        requested = [alias.name for alias in node.names]
        try:
            module = __import__(node.module, fromlist=requested)
        except (ImportError, AttributeError):
            return None

        # ``*`` is not an attribute of the module; a star import is valid
        # as soon as the module itself imports.
        resolvable = [
            alias for alias in node.names
            if alias.name == '*' or hasattr(module, alias.name)
        ]
        if not resolvable:
            return None
        node.names = resolvable
        return node
 
class SymbolCollector(ast.NodeVisitor):
    """Collect the names that a module makes available for annotations.

    After ``visit(tree)`` the ``symbols`` set contains:

    * names bound by ``import`` / ``from ... import`` statements,
    * names of all classes and (sync or async) functions, at any depth,
    * names assigned at module level (type aliases, ``TypeVar``s, ...).
    """

    def __init__(self):
        self.symbols = set()

    def visit_Module(self, node):
        """Record module-level assignment targets, then walk the module."""
        for stmt in node.body:
            if isinstance(stmt, ast.Assign):
                targets = stmt.targets
            elif isinstance(stmt, ast.AnnAssign):
                targets = [stmt.target]
            else:
                continue
            for target in targets:
                if isinstance(target, ast.Name):
                    self.symbols.add(target.id)
        self.generic_visit(node)

    def visit_Import(self, node):
        """Record the name each ``import`` alias binds."""
        for alias in node.names:
            # ``import a.b`` binds only the top-level package name ``a``.
            self.symbols.add(alias.asname or alias.name.split('.')[0])

    def visit_ImportFrom(self, node):
        """Record the name each ``from ... import`` alias binds."""
        for alias in node.names:
            self.symbols.add(alias.asname or alias.name)

    def visit_ClassDef(self, node):
        """Record the class name and descend into its body."""
        self.symbols.add(node.name)
        self.generic_visit(node)

    def visit_FunctionDef(self, node):
        """Record the function name and descend into its body."""
        self.symbols.add(node.name)
        self.generic_visit(node)

    # ``async def`` binds a name exactly like ``def`` does.
    visit_AsyncFunctionDef = visit_FunctionDef
 
class RemoveUnknownAnnotations(ast.NodeTransformer):
    """Drop annotations that mention names the module cannot resolve.

    A name counts as known when it is in ``known_symbols`` or is a Python
    builtin (``int``, ``str``, ...). A dotted name counts as known when
    the full dotted path or its last component is in ``known_symbols``,
    or when its root name is known (e.g. ``typing.Any`` after
    ``import typing``).

    * Annotated assignments with an unknown annotation are removed.
    * Function parameters and return values only lose the annotation.
    """

    def __init__(self, known_symbols):
        """Store *known_symbols*, a container of resolvable names."""
        self.known_symbols = known_symbols

    @staticmethod
    def _is_builtin(name):
        """Return True if *name* is available without any import."""
        # Local import: the file's import header does not provide it.
        import builtins

        return hasattr(builtins, name)

    def _is_known_name(self, name):
        """Return True if the bare identifier *name* can be resolved."""
        return name in self.known_symbols or self._is_builtin(name)

    @staticmethod
    def _iter_arguments(arguments):
        """Yield every ``ast.arg`` of a signature, including * and **."""
        # ``posonlyargs`` only exists on Python 3.8+.
        yield from getattr(arguments, 'posonlyargs', [])
        yield from arguments.args
        if arguments.vararg is not None:
            yield arguments.vararg
        yield from arguments.kwonlyargs
        if arguments.kwarg is not None:
            yield arguments.kwarg

    def visit_ClassDef(self, node):
        """Process the class body, keeping the class syntactically valid."""
        self.generic_visit(node)
        # Removing every annotated attribute would leave an empty body,
        # which cannot be rendered as valid source.
        if not node.body:
            node.body = [ast.Pass()]
        return node

    def visit_AnnAssign(self, node):
        """Delete the whole statement when its annotation is unknown."""
        if self._is_unknown(node.annotation):
            return None
        return node

    def visit_FunctionDef(self, node):
        """Clear unknown return and parameter annotations in place.

        The function body is intentionally not visited.
        """
        if node.returns is not None and self._is_unknown(node.returns):
            node.returns = None

        for arg in self._iter_arguments(node.args):
            if arg.annotation is not None and self._is_unknown(arg.annotation):
                arg.annotation = None

        return node

    # ``async def`` signatures are sanitised the same way.
    visit_AsyncFunctionDef = visit_FunctionDef

    def _check_subscript_slice(self, slice_node: ast.AST) -> bool:
        """Return True if the ``[...]`` part of a subscript is unknown."""
        # ``ast.Index`` wrapped subscript values before Python 3.9 and is
        # deprecated since; look it up defensively in case it is absent.
        legacy_index = getattr(ast, 'Index', None)
        if legacy_index is not None and isinstance(slice_node, legacy_index):
            return self._is_unknown(slice_node.value)
        return self._is_unknown(slice_node)

    def _is_unknown(self, node: ast.AST) -> bool:
        """Return True if the annotation expression uses an unknown name.

        Node types not handled below (constants such as ``None`` or
        string annotations, for instance) are treated as known.
        """
        # ``A | B`` unions: unknown if either side is.
        if isinstance(node, ast.BinOp):
            return (self._is_unknown(node.left) or
                    self._is_unknown(node.right))

        # ``X[A, B]`` parameter tuples and ``Callable[[A], B]`` lists.
        if isinstance(node, (ast.Tuple, ast.List)):
            return any(self._is_unknown(elt) for elt in node.elts)

        if isinstance(node, ast.Subscript):
            return (self._is_unknown(node.value) or
                    self._check_subscript_slice(node.slice))

        if isinstance(node, ast.Attribute):
            # Flatten ``a.b.c`` into ['a', 'b', 'c'].
            parts = []
            current = node
            while isinstance(current, ast.Attribute):
                parts.append(current.attr)
                current = current.value
            if not isinstance(current, ast.Name):
                # Something like ``f().x``: cannot be resolved statically.
                return True
            parts.append(current.id)
            parts.reverse()
            known = (
                ".".join(parts) in self.known_symbols
                or parts[-1] in self.known_symbols
                or self._is_known_name(parts[0])
            )
            return not known

        if isinstance(node, ast.Name):
            return not self._is_known_name(node.id)

        return False
 
def process_pyi(input_file, output_file):
    """Clean a generated ``.pyi`` stub and write the result to a new file.

    The stub is parsed and passed through three clean-up steps, in order:
    underscore-prefixed class variables are removed, imports that cannot
    be resolved are removed, and annotations that refer to names no
    longer known to the module are removed. The resulting tree is
    rendered back to source with ``astor``.

    Args:
        input_file: Path of the stub file to read.
        output_file: Path of the cleaned stub file to write; its parent
            directory must already exist.
    """
    # Read as UTF-8 explicitly (the default encoding of Python source)
    # rather than relying on the platform default, which on some systems
    # is a legacy code page that cannot decode the stub.
    with open(input_file, 'r', encoding='utf-8') as f:
        tree = ast.parse(f.read())

    # remove underscore vars
    tree = RemoveUnderscoreVars().visit(tree)
    ast.fix_missing_locations(tree)

    # remove failed imports
    tree = RemoveFailedImports().visit(tree)
    ast.fix_missing_locations(tree)

    # collect symbols (after the import clean-up, so that names coming
    # from removed imports are not considered known)
    collector = SymbolCollector()
    collector.visit(tree)
    known_symbols = collector.symbols

    # remove unknown annotations
    tree = RemoveUnknownAnnotations(known_symbols).visit(tree)
    ast.fix_missing_locations(tree)

    new_code = astor.to_source(tree)
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(new_code)
 
if __name__ == '__main__':
    # Step 1: have mypy's stubgen write a raw stub for the module into
    # the "tmp" directory.
    stubgen_options = mypy.stubgen.parse_options(
        ['-m', 'QCustomPlot_PyQt5', '-o', 'tmp']
    )
    mypy.stubgen.generate_stubs(stubgen_options)

    # Step 2: clean the raw stub and store the result in "output".
    os.makedirs("output", exist_ok=True)
    process_pyi(
        'tmp/QCustomPlot_PyQt5.pyi',
        'output/QCustomPlot_PyQt5.pyi',
    )

执行后,会在运行脚本时的当前工作目录(通常即脚本所在目录)的output子目录中生成QCustomPlot_PyQt5.pyi接口描述文件。请将该文件复制到安装了QCustomPlot-PyQt5的Python环境根目录的Lib/site-packages目录中(例如,如果您使用了名为DemoEnv的Anaconda环境,您可能需要将该文件复制到%ProgramData%\anaconda3\envs\DemoEnv\Lib\site-packages目录中)即可。

参考资料:

https://github.com/salsergey/QCustomPlot-PyQt/issues/12

it
除非特别注明,本页内容采用以下授权方式: Creative Commons Attribution-ShareAlike 3.0 License