Skip to content

Commit

Permalink
feat: expose Querier class
Browse files Browse the repository at this point in the history
  • Loading branch information
nohehf committed May 12, 2024
1 parent 7e198c0 commit 96ebb43
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 35 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -72,5 +72,6 @@ docs/_build/
.python-version
# stack graph sqlite dbs
*.db
*.sqlite

.mypy_cache/
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ classifiers = [
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
version = "0.0.4"
version = "0.0.5"
[tool.maturin]
features = ["pyo3/extension-module"]
34 changes: 34 additions & 0 deletions src/classes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ use std::fmt::Display;

use pyo3::prelude::*;

use stack_graphs::storage::SQLiteReader;
use tree_sitter_stack_graphs::cli::util::{SourcePosition, SourceSpan};

use crate::stack_graphs_wrapper::query_definition;

#[pyclass]
#[derive(Clone)]
pub enum Language {
Expand All @@ -24,6 +27,37 @@ pub struct Position {
column: usize,
}

#[pyclass]
pub struct Querier {
db_reader: SQLiteReader,
}

#[pymethods]
impl Querier {
#[new]
pub fn new(db_path: String) -> Self {
println!("Opening database: {}", db_path);
Querier {
db_reader: SQLiteReader::open(db_path).unwrap(),
}
}

pub fn definitions(&mut self, reference: Position) -> PyResult<Vec<Position>> {
let result = query_definition(reference.into(), &mut self.db_reader)?;

let positions: Vec<Position> = result
.into_iter()
.map(|r| r.targets)
.flatten()
.map(|t| t.into())
.collect();

Ok(positions)
}
}

// TODO(@nohehf): Indexer class

#[pymethods]
impl Position {
#[new]
Expand Down
27 changes: 4 additions & 23 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use pyo3::prelude::*;
mod classes;
mod stack_graphs_wrapper;

use classes::{Language, Position};
use classes::{Language, Position, Querier};

/// Formats the sum of two numbers as string.
#[pyfunction]
Expand All @@ -14,8 +14,8 @@ fn sum_as_string(a: usize, b: usize) -> PyResult<String> {
#[pyfunction]
fn index(paths: Vec<String>, db_path: String, language: Language) -> PyResult<()> {
// TODO(@nohehf): Add a verbose mode to toggle the logs
println!("Indexing paths: {:?}", paths);
println!("Database path: {:?}", db_path);
// println!("Indexing paths: {:?}", paths);
// println!("Database path: {:?}", db_path);

let paths: Vec<std::path::PathBuf> =
paths.iter().map(|p| std::path::PathBuf::from(p)).collect();
Expand All @@ -27,32 +27,13 @@ fn index(paths: Vec<String>, db_path: String, language: Language) -> PyResult<()
)?)
}

/// Indexes the given paths into stack graphs, and stores the results in the given database.
#[pyfunction]
fn query_definition(reference: Position, db_path: String) -> PyResult<Vec<Position>> {
println!("Querying reference: {:?}", reference.to_string());
println!("Database path: {:?}", db_path);

let result = stack_graphs_wrapper::query_definition(reference.into(), &db_path)?;

// TODO(@nohehf): Check if we can flatten the results, see the QueryResult struct, we might be loosing some information
let positions: Vec<Position> = result
.into_iter()
.map(|r| r.targets)
.flatten()
.map(|t| t.into())
.collect();

Ok(positions)
}

/// A Python module implemented in Rust.
#[pymodule]
fn stack_graphs_python(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(sum_as_string, m)?)?;
m.add_function(wrap_pyfunction!(index, m)?)?;
m.add_function(wrap_pyfunction!(query_definition, m)?)?;
m.add_class::<Position>()?;
m.add_class::<Language>()?;
m.add_class::<Querier>()?;
Ok(())
}
6 changes: 2 additions & 4 deletions src/stack_graphs_wrapper/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,11 @@ pub fn index(

pub fn query_definition(
reference: SourcePosition,
db_path: &str,
db_reader: &mut SQLiteReader,
) -> Result<Vec<QueryResult>, StackGraphsError> {
let mut db_read = SQLiteReader::open(&db_path).expect("failed to open database");

let reporter = ConsoleReporter::none();

let mut querier = Querier::new(&mut db_read, &reporter);
let mut querier = Querier::new(db_reader, &reporter);

// print_source_position(&reference);

Expand Down
5 changes: 4 additions & 1 deletion stack_graphs_python.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,8 @@ class Position:

def __init__(self, path: str, line: int, column: int) -> None: ...

class Querier:
def __init__(self, db_path: str) -> None: ...
def definitions(self, reference: Position) -> list[Position]: ...

def index(paths: list[str], db_path: str, language: Language) -> None: ...
def query_definition(reference: Position, db_path: str) -> list[Position]: ...
15 changes: 9 additions & 6 deletions tests/test.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
# TODO(@nohehf): Make this a propper pytest test & run in CI
import os
from stack_graphs_python import index, query_definition, Position, Language
from stack_graphs_python import index, Querier, Position, Language

# index ./js_sample directory

# convert ./js_sample directory to absolute path
dir = os.path.abspath("./tests/js_sample")
db = os.path.abspath("./js_sample.db")
db_path = os.path.abspath("./db.sqlite")

print("Indexing directory: ", dir)
print("Database path: ", db)
print("Database path: ", db_path)

index([dir], db, language=Language.Python)
index([dir], db_path, language=Language.JavaScript)

source_reference: Position = Position(path=dir + "/index.js", line=2, column=12)
source_reference = Position(path=dir + "/index.js", line=2, column=12)

print("Querying definition for: ", source_reference.path)

results = query_definition(source_reference, db)
querier = Querier(db_path)

results = querier.definitions(source_reference)

print("Results: ", results)

Expand Down

0 comments on commit 96ebb43

Please sign in to comment.