"""Core chunking functions used by multiple modules."""
from __future__ import annotations
import re
from pathlib import Path
from typing import TYPE_CHECKING
from .languages import language_config_registry
__all__ = [
"chunk_file",
"chunk_text",
]
# Note: .multi_language is imported lazily inside chunk_file to avoid a
# circular import; the imports below are safe to do eagerly.
from .metadata import MetadataExtractorFactory
from .parser import get_parser
from .types import CodeChunk, compute_file_id, compute_node_id, compute_symbol_id
if TYPE_CHECKING:
from tree_sitter import Node
def _extract_definition_name(node: Node, source: bytes) -> str | None:
"""Extract the definition name from an AST node.
Tries common field names used across languages:
- "name" (most common: Python, JS, TS, Go, Rust, etc.)
- "identifier" (some grammars)
- "declarator" then "name" (C/C++ style)
Returns None if no name can be extracted (anonymous definition).
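
    Example (illustrative): for the Python source ``def foo(): ...`` the
    ``function_definition`` node exposes a ``name`` field, so this function
    returns ``"foo"``.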
"""
    def _field(n: Node, field_name: str) -> Node | None:
        return getattr(n, "child_by_field_name", lambda _n: None)(field_name)

    def _text(n: Node) -> str:
        return source[n.start_byte : n.end_byte].decode("utf-8", errors="ignore")

    # Try direct "name" field first (most common)
    name_node = _field(node, "name")
    if name_node is not None:
        return _text(name_node)
    # Try "identifier" field (some grammars)
    id_node = _field(node, "identifier")
    if id_node is not None:
        return _text(id_node)
    # Try declarator pattern (C/C++ style: type declarator { name })
    declarator = _field(node, "declarator")
    if declarator is not None:
        decl_name = _field(declarator, "name")
        if decl_name is not None:
            return _text(decl_name)
        # Some declarators ARE the name directly (identifier type)
        if getattr(declarator, "type", "") == "identifier":
            return _text(declarator)
return None
def _parse_route_segment(segment: str) -> tuple[str, str]:
"""Split a qualified route segment into node type and name."""
if ":" not in segment:
return segment, segment
node_type, name = segment.split(":", 1)
return node_type, name
def _normalize_kind(node_type: str) -> str:
"""Normalize language-specific node types to retrieval-friendly kinds."""
lowered = (node_type or "").lower()
if "constructor" in lowered:
return "constructor"
if "method" in lowered:
return "method"
if any(token in lowered for token in {"function", "func", "procedure"}):
return "function"
if "class" in lowered:
return "class"
if "interface" in lowered:
return "interface"
if "trait" in lowered:
return "trait"
if "enum" in lowered:
return "enum"
if "struct" in lowered:
return "struct"
if any(token in lowered for token in {"module", "namespace", "package"}):
return "module"
if "type" in lowered:
return "type"
return node_type
def _route_symbol_name(segment: str) -> str | None:
"""Extract a readable symbol name from a route segment."""
_node_type, name = _parse_route_segment(segment)
if not name or name.startswith("anon@"):
return None
return name
def _format_signature_text(signature: object) -> str | None:
"""Format signature metadata into a normalized string."""
if signature is None:
return None
if isinstance(signature, str):
return signature
if not isinstance(signature, dict):
return str(signature)
name = signature.get("name")
params = []
for param in signature.get("parameters", []):
if isinstance(param, dict):
param_text = param.get("name", "?")
if param.get("type"):
param_text += f": {param['type']}"
if param.get("default") is not None:
param_text += f" = {param['default']}"
else:
param_text = str(param)
params.append(param_text)
signature_text = f"({', '.join(params)})"
if name:
signature_text = f"{name}{signature_text}"
if signature.get("return_type"):
signature_text += f" -> {signature['return_type']}"
return signature_text
def _normalize_text_list(value: object) -> list[str]:
"""Normalize metadata values into a stable string list."""
if value is None:
return []
if isinstance(value, str):
values = [value]
elif isinstance(value, (list, tuple, set)):
values = [str(item) for item in value if item is not None]
else:
values = [str(value)]
return sorted(
dict.fromkeys(item.strip() for item in values if item and item.strip())
)
def _build_semantic_text(
chunk: CodeChunk, retrieval_metadata: dict[str, object]
) -> str:
"""Build retrieval-oriented semantic text for a chunk."""
lines = [
f"language: {chunk.language}",
f"file: {chunk.file_path or '<memory>'}",
f"kind: {retrieval_metadata['kind']}",
]
for key in (
"symbol",
"qualified_name",
"parent_symbol",
"semantic_path",
"signature_text",
):
value = retrieval_metadata.get(key)
if value:
label = key.replace("_", " ")
lines.append(f"{label}: {value}")
for key in ("imports", "exports", "dependencies"):
values = _normalize_text_list(retrieval_metadata.get(key))
if values:
lines.append(f"{key}: {', '.join(values)}")
docstring = chunk.metadata.get("docstring") if chunk.metadata else None
if docstring:
lines.append(f"docstring: {docstring}")
if chunk.content:
lines.append("content:")
lines.append(chunk.content)
return "\n".join(lines)
def _build_retrieval_metadata(chunk: CodeChunk) -> dict[str, object]:
"""Derive language-agnostic retrieval metadata from a chunk."""
metadata = chunk.metadata or {}
route_symbols = [
name
for name in (_route_symbol_name(segment) for segment in chunk.qualified_route)
if name
]
signature = metadata.get("signature")
symbol = None
if isinstance(signature, dict):
symbol = signature.get("name")
elif isinstance(metadata.get("symbol"), str):
symbol = metadata.get("symbol")
if not symbol and route_symbols:
symbol = route_symbols[-1]
if not symbol and chunk.parent_context:
symbol = chunk.parent_context.split(".")[-1]
symbol = str(symbol) if symbol else None
parent_symbol = route_symbols[-2] if len(route_symbols) > 1 else None
kind = _normalize_kind(chunk.node_type)
parent_node_types = [
_parse_route_segment(segment)[0] for segment in chunk.qualified_route[:-1]
]
if kind == "function" and any(
_normalize_kind(node_type) in {"class", "struct", "interface", "trait"}
for node_type in parent_node_types
):
kind = "method"
qualified_name = ".".join(route_symbols) if route_symbols else symbol
semantic_path = chunk.file_path or "<memory>"
if qualified_name:
semantic_path = f"{semantic_path}::{qualified_name}"
retrieval_metadata: dict[str, object] = {
"kind": kind,
"symbol": symbol,
"qualified_name": qualified_name,
"parent_symbol": parent_symbol,
"semantic_path": semantic_path,
"signature_text": _format_signature_text(signature),
"imports": _normalize_text_list(metadata.get("imports")),
"exports": _normalize_text_list(metadata.get("exports")),
"dependencies": _normalize_text_list(
metadata.get("dependencies", chunk.dependencies)
),
}
retrieval_metadata["semantic_text"] = _build_semantic_text(
chunk, retrieval_metadata
)
return retrieval_metadata
def _apply_retrieval_metadata(chunks: list[CodeChunk]) -> None:
"""Attach normalized retrieval metadata to each chunk in-place."""
for chunk in chunks:
if chunk.metadata is None:
chunk.metadata = {}
chunk.metadata.update(_build_retrieval_metadata(chunk))
def _walk(
node: Node,
source: bytes,
language: str,
parent_ctx: str | None = None,
parent_chunk: CodeChunk | None = None,
extractor=None,
analyzer=None,
parent_route: list[str] | None = None,
parent_qualified_route: list[str] | None = None,
) -> list[CodeChunk]:
"""Walk the AST and extract chunks based on language configuration."""
# Get language configuration
config = language_config_registry.get(language)
if not config:
# Fallback to hardcoded defaults for backward compatibility
if language in {"csharp", "c_sharp"}:
# Tree-sitter C# node types
chunk_types = {
"class_declaration",
"struct_declaration",
"interface_declaration",
"enum_declaration",
"method_declaration",
"constructor_declaration",
"property_declaration",
"field_declaration",
"record_declaration",
}
else:
chunk_types = {
"function_definition",
"class_definition",
"method_definition",
}
def should_chunk(node_type: str) -> bool:
return node_type in chunk_types
def should_ignore(_node_type: str) -> bool:
return False
else:
should_chunk = config.should_chunk_node
should_ignore = config.should_ignore_node # type: ignore[assignment]
# Go: ensure common declaration node types are chunked even if rules are minimal
if language == "go":
go_decl_like = {
"function_declaration",
"method_declaration",
"type_declaration",
"type_spec",
"const_declaration",
"var_declaration",
}
def should_chunk(node_type: str) -> bool: # type: ignore[no-redef]
return config.should_chunk_node(node_type) or node_type in go_decl_like
# For LISPy languages like Clojure, treat top-level list forms as chunks
if language == "clojure":
def should_chunk(node_type: str) -> bool: # type: ignore[no-redef]
return node_type == "list_lit" or config.should_chunk_node(node_type)
# For Dart, the grammar exposes separate signature/body nodes. Treat
# signatures as declarations for chunking.
elif language == "dart":
dart_signature_types = {
"function_signature",
"method_signature",
"getter_signature",
"setter_signature",
"constructor_signature",
"factory_constructor_signature",
}
dart_extra_decl_like = {"class_definition", "type_alias"}
def should_chunk(node_type: str) -> bool: # type: ignore[no-redef]
return (
config.should_chunk_node(node_type)
or node_type in dart_signature_types
or node_type in dart_extra_decl_like
)
chunks: list[CodeChunk] = []
current_chunk = None
current_qualified_route: list[str] | None = None
# Skip ignored nodes
if should_ignore(node.type):
return chunks
# Ensure route lists
parent_route = (parent_route or []).copy()
parent_qualified_route = (parent_qualified_route or []).copy()
# R special-cases: treat setClass/setMethod calls as chunks
force_chunk = False
r_call_name: str | None = None
if language == "r" and node.type == "call":
try:
callee = (getattr(node, "children", None) or [None])[0]
if callee is not None and getattr(callee, "type", None) == "identifier":
ident = source[callee.start_byte : callee.end_byte].decode(
"utf-8",
errors="ignore",
)
if ident in {"setClass", "setMethod", "setGeneric"}:
force_chunk = True
r_call_name = ident
except Exception:
pass
# Check if this node should be a chunk
if should_chunk(node.type) or force_chunk:
# Default span covers the current node
span_start = node.start_byte
span_end = node.end_byte
adjusted_node_type = node.type
# Dart: merge signature + body into a single declaration chunk
if language == "dart":
dart_sig_to_decl = {
"function_signature": "function_declaration",
"method_signature": "method_declaration",
"getter_signature": "getter_declaration",
"setter_signature": "setter_declaration",
"constructor_signature": "constructor_declaration",
"factory_constructor_signature": "factory_constructor",
}
if node.type in dart_sig_to_decl:
adjusted_node_type = dart_sig_to_decl[node.type]
# Find following function_body sibling under same parent
parent = getattr(node, "parent", None)
if parent is not None:
try:
children = list(parent.children)
idx = children.index(node)
for sib in children[idx + 1 :]:
if sib.type == "function_body":
span_end = sib.end_byte
break
except Exception:
pass
elif node.type == "class_definition":
# Normalize to expected name used in tests/config
adjusted_node_type = "class_declaration"
elif node.type == "type_alias":
# Normalize Dart type aliases to typedef_declaration for tests
adjusted_node_type = "typedef_declaration"
# Elixir: reinterpret certain call forms as declarations
elif language == "elixir":
if node.type == "call":
try:
for child in getattr(node, "children", []) or []:
if getattr(child, "type", None) == "identifier":
ident = source[child.start_byte : child.end_byte].decode(
"utf-8",
errors="ignore",
)
if ident in {"def", "defp", "defmacro", "defmacrop"}:
adjusted_node_type = "function_definition"
break
if ident == "defmodule":
adjusted_node_type = "module_definition"
break
if ident == "defprotocol":
adjusted_node_type = "protocol_definition"
break
if ident == "defimpl":
adjusted_node_type = "implementation_definition"
break
if ident == "defstruct":
adjusted_node_type = "struct_definition"
break
except Exception:
pass
elif language == "haskell":
        # Normalize Haskell node variants to canonical names expected by tests.
        # "type_synomym" (sic) is the node name as spelled by the
        # tree-sitter-haskell grammar, so it must be matched verbatim.
        if node.type in {"type_declaration", "type_synomym"}:
            adjusted_node_type = "type_synonym"
elif node.type == "class":
adjusted_node_type = "class_declaration"
elif node.type == "instance":
adjusted_node_type = "instance_declaration"
elif node.type == "header":
adjusted_node_type = "module_declaration"
elif language == "scala":
# Scala: detect case classes and adjust node type
if node.type == "class_definition":
# Check if this is a case class by examining first child
if node.children and node.children[0].type == "case":
adjusted_node_type = "case_class_definition"
elif node.type == "function_definition":
# Check if this function is inside a class/trait/object (making it a method)
parent = getattr(node, "parent", None)
while parent:
if parent.type in {
"class_definition",
"trait_definition",
"object_definition",
"template_body",
}:
adjusted_node_type = "method_definition"
break
parent = getattr(parent, "parent", None)
elif language == "julia":
# Map Julia assignment nodes that are actually function definitions
if node.type == "assignment":
# Check if left side is a call_expression (function signature)
for child in node.children:
if child.type == "call_expression":
adjusted_node_type = "short_function_definition"
break
elif node.type == "abstract_definition":
adjusted_node_type = "abstract_type_definition"
elif node.type == "primitive_definition":
adjusted_node_type = "primitive_type_definition"
elif language == "sql":
# Map tree-sitter SQL node types to expected test node types
if node.type in {
"insert",
"update",
"delete",
"select",
} or node.type.startswith("create_"):
adjusted_node_type = f"{node.type}_statement"
elif node.type == "ERROR":
# Handle ERROR nodes that might be CREATE PROCEDURE/FUNCTION
content = (
source[node.start_byte : node.end_byte]
.decode(
"utf-8",
errors="replace",
)
.lower()
)
if "create procedure" in content:
adjusted_node_type = "create_procedure_statement"
elif "create function" in content:
adjusted_node_type = "create_function_statement"
# else keep as ERROR (will be filtered out later)
# Clojure: reinterpret list forms as their defining form (defn, def, etc.)
if language == "clojure" and node.type == "list_lit":
# Use named children to skip parentheses and punctuation tokens
children = list(getattr(node, "named_children", []) or [])
if len(children) >= 1 and children[0].type == "sym_lit":
form_name = (
source[children[0].start_byte : children[0].end_byte]
.decode("utf-8", errors="replace")
.strip()
)
if form_name in {
"defn",
"defn-",
"def",
"defmacro",
"defprotocol",
"deftype",
"defrecord",
"defmulti",
"defmethod",
"defonce",
"defstruct",
}:
adjusted_node_type = form_name
# For R, include the name identifier for assignment-based function defs by
# expanding the span to include the full assignment expression
if language == "r" and node.type in {"function_definition"}:
parent = getattr(node, "parent", None)
if parent is not None and parent.type in {
"assignment",
"left_assignment",
"right_assignment",
"binary_operator",
}:
# Expand to the parent span so the chunk content includes the name and <-
span_start = min(span_start, parent.start_byte)
span_end = max(span_end, parent.end_byte)
# For R special call chunks, normalize node type to the callee name
if language == "r" and force_chunk and r_call_name:
adjusted_node_type = r_call_name
        text = source[span_start:span_end].decode("utf-8", errors="replace")
current_route = [*parent_route, adjusted_node_type]
# Build qualified_route with definition names for content-insensitive ID
        start_line = (
            node.start_point[0] + 1
            if span_start == node.start_byte
            # Span was expanded (e.g. R assignment); derive the line from the
            # byte offset so line and byte spans agree.
            else source.count(b"\n", 0, span_start) + 1
        )
def_name = _extract_definition_name(node, source)
if def_name:
qualified_name = f"{adjusted_node_type}:{def_name}"
else:
# Anonymous definition - use line number as fallback
qualified_name = f"{adjusted_node_type}:anon@{start_line}"
current_qualified_route = [*(parent_qualified_route or []), qualified_name]
current_chunk = CodeChunk(
language=language,
file_path="",
node_type=adjusted_node_type,
start_line=start_line,
            end_line=(
                node.end_point[0] + 1
                if span_end == node.end_byte
                # Span was expanded beyond this node (Dart signature + body,
                # R assignment), so derive the line from the byte offset.
                else source.count(b"\n", 0, span_end) + 1
            ),
byte_start=span_start,
byte_end=span_end,
parent_context=parent_ctx or "",
content=text,
parent_chunk_id=(parent_chunk.node_id if parent_chunk else None),
parent_route=current_route,
qualified_route=current_qualified_route,
)
# Extract metadata if extractors are available
if extractor or analyzer:
metadata = {}
if extractor:
# Extract signature
signature = extractor.extract_signature(node, source)
if signature:
metadata["signature"] = {
"name": signature.name,
"parameters": signature.parameters,
"return_type": signature.return_type,
"decorators": signature.decorators,
"modifiers": signature.modifiers,
}
# Extract docstring
docstring = extractor.extract_docstring(node, source)
if docstring:
metadata["docstring"] = docstring
# Extract dependencies
dependencies = extractor.extract_dependencies(node, source)
metadata["dependencies"] = sorted(dependencies) if dependencies else []
current_chunk.dependencies = (
sorted(dependencies) if dependencies else []
)
# Extract imports
imports = extractor.extract_imports(node, source)
if imports:
metadata["imports"] = imports
# Extract exports
exports = extractor.extract_exports(node, source)
if exports:
metadata["exports"] = sorted(exports)
# Extract calls with spans
calls = extractor.extract_calls(node, source)
if calls:
# Backward compatibility: extract just names
metadata["calls"] = [call["name"] for call in calls]
# New detailed format: include spans
metadata["call_spans"] = calls
if analyzer:
# Calculate complexity metrics
complexity = analyzer.analyze_complexity(node, source)
metadata["complexity"] = {
"cyclomatic": complexity.cyclomatic,
"cognitive": complexity.cognitive,
"nesting_depth": complexity.nesting_depth,
"lines_of_code": complexity.lines_of_code,
"logical_lines": complexity.logical_lines,
}
# Set metadata type field for all languages when metadata extraction is enabled
if language in {"typescript", "tsx"}:
# TypeScript-specific metadata type mapping
if adjusted_node_type == "interface_declaration":
metadata["type"] = "interface_declaration"
elif adjusted_node_type == "type_alias_declaration":
metadata["type"] = "type_alias_declaration"
elif adjusted_node_type == "enum_declaration":
metadata["type"] = "enum_declaration"
elif adjusted_node_type in {"internal_module", "module"}:
metadata["type"] = "namespace_declaration"
elif adjusted_node_type == "abstract_class_declaration":
metadata["type"] = "class_declaration"
metadata["abstract"] = True
else:
metadata["type"] = adjusted_node_type
else:
# For other languages, set type to node_type by default
metadata["type"] = adjusted_node_type
current_chunk.metadata = metadata
else:
# For compatibility, even if no extractors create an empty metadata dict
current_chunk.metadata = {}
chunks.append(current_chunk)
# Set better context for select languages
if language == "go":
try:
if adjusted_node_type in {
"function_declaration",
"method_declaration",
"type_spec",
"type_declaration",
}:
name_node = getattr(node, "child_by_field_name", lambda _n: None)(
"name",
)
if name_node is not None:
item_name = source[
name_node.start_byte : name_node.end_byte
].decode(
"utf-8",
errors="ignore",
)
# Assign entity name to parent_context for compatibility with tests
current_chunk.parent_context = item_name
except Exception:
pass
# Dart: also emit a synthetic 'widget_class' chunk for Flutter widgets
if (
language == "dart"
and adjusted_node_type == "class_declaration"
and ("extends StatelessWidget" in text or "extends StatefulWidget" in text)
):
widget_chunk = CodeChunk(
language=language,
file_path="",
node_type="widget_class",
start_line=current_chunk.start_line,
end_line=current_chunk.end_line,
byte_start=current_chunk.byte_start,
byte_end=current_chunk.byte_end,
parent_context=current_chunk.parent_context,
content=current_chunk.content,
parent_chunk_id=current_chunk.parent_chunk_id,
parent_route=[*current_route, "widget_class"],
)
chunks.append(widget_chunk)
# Vue: synthesize a component_definition from script contents
if language == "vue" and node.type == "script_element":
try:
script_text = text
if ("export default" in script_text) or (
"defineComponent" in script_text
):
comp_chunk = CodeChunk(
language=language,
file_path="",
node_type="component_definition",
start_line=current_chunk.start_line,
end_line=current_chunk.end_line,
byte_start=current_chunk.byte_start,
byte_end=current_chunk.byte_end,
parent_context="script_element",
content=current_chunk.content,
parent_chunk_id=current_chunk.parent_chunk_id,
parent_route=[*current_route, "component_definition"],
)
chunks.append(comp_chunk)
except Exception:
pass
# Svelte: synthesize reactive_statement chunks from script contents
if language == "svelte" and node.type == "script_element":
try:
# Extract raw script body between tags if present
script_text = text
body = script_text
gt = script_text.find(">")
end_tag = script_text.rfind("</")
if gt != -1 and end_tag != -1 and gt + 1 < end_tag:
body = script_text[gt + 1 : end_tag]
lines = body.splitlines()
for idx, line in enumerate(lines):
stripped = line.strip()
if stripped.startswith("$:"):
reactive_chunk = CodeChunk(
language=language,
file_path="",
node_type="reactive_statement",
start_line=current_chunk.start_line + idx,
end_line=current_chunk.start_line + idx,
byte_start=0,
byte_end=0,
parent_context="script_element",
content=line,
parent_chunk_id=current_chunk.parent_chunk_id,
parent_route=[*current_route, "reactive_statement"],
)
chunks.append(reactive_chunk)
except Exception:
pass
# Svelte: synthesize control-flow chunks by scanning entire file once at top-level
if language == "svelte" and parent_chunk is None:
try:
full_text = source.decode("utf-8", errors="replace")
lines = full_text.splitlines()
for idx, line in enumerate(lines, start=1):
stripped = line.strip()
cf_type = None
if stripped.startswith("{#if"):
cf_type = "if_block"
elif stripped.startswith("{#each"):
cf_type = "each_block"
elif stripped.startswith("{#await"):
cf_type = "await_block"
elif stripped.startswith("{#key"):
cf_type = "key_block"
if cf_type:
cf_chunk = CodeChunk(
language=language,
file_path="",
node_type=cf_type,
start_line=idx,
end_line=idx,
byte_start=0,
byte_end=0,
parent_context="template",
content=line,
parent_chunk_id=None,
parent_route=[cf_type],
)
chunks.append(cf_chunk)
except Exception:
pass
        # Children of this chunk see it as their enclosing context
        # (nested functions, methods inside classes, etc.)
        parent_ctx = node.type
parent_route = current_route
parent_qualified_route = current_qualified_route
# Walk children with current chunk as parent
for child in node.children:
chunks.extend(
_walk(
child,
source,
language,
parent_ctx,
current_chunk or parent_chunk,
extractor,
analyzer,
parent_route=parent_route,
parent_qualified_route=parent_qualified_route,
),
)
# Julia-specific post-processing: merge preceding comments with definitions
if language == "julia":
chunks = _merge_julia_comments_with_definitions(chunks)
# MATLAB-specific post-processing: detect scripts
if language == "matlab":
chunks = _detect_matlab_scripts(chunks, node, source, parent_chunk)
return chunks
def _merge_julia_comments_with_definitions(chunks: list[CodeChunk]) -> list[CodeChunk]:
"""Merge Julia comment chunks with following definition chunks."""
if not chunks:
return chunks
merged_chunks = []
i = 0
while i < len(chunks):
current_chunk = chunks[i]
# Check if this is a line comment followed by a definition
if (
current_chunk.node_type == "line_comment"
and i + 1 < len(chunks)
and chunks[i + 1].node_type
in {
"struct_definition",
"function_definition",
"module_definition",
"macro_definition",
"macrocall_expression",
"abstract_definition",
"primitive_definition",
"abstract_type_definition",
"primitive_type_definition",
}
):
next_chunk = chunks[i + 1]
# Check if they're adjacent (comment right before definition)
if (
current_chunk.end_line + 1 == next_chunk.start_line
or current_chunk.end_line == next_chunk.start_line
):
# Merge the comment content with the definition content
merged_content = current_chunk.content + "\n" + next_chunk.content
                # Create a new chunk with the merged content
                merged_chunk = CodeChunk(
                    language=next_chunk.language,
                    file_path=next_chunk.file_path,
                    node_type=next_chunk.node_type,
                    start_line=current_chunk.start_line,
                    end_line=next_chunk.end_line,
                    byte_start=current_chunk.byte_start,
                    byte_end=next_chunk.byte_end,
                    parent_context=next_chunk.parent_context,
                    content=merged_content,
                    parent_chunk_id=next_chunk.parent_chunk_id,
                    parent_route=next_chunk.parent_route,
                    qualified_route=next_chunk.qualified_route,
                )
                # Carry over metadata from the definition chunk (CodeChunk
                # always defines the attribute, so no hasattr guard is needed)
                merged_chunk.metadata = next_chunk.metadata
merged_chunks.append(merged_chunk)
i += 2 # Skip both chunks
continue
# If not merging, just add the current chunk
merged_chunks.append(current_chunk)
i += 1
return merged_chunks
def _detect_matlab_scripts(
chunks: list[CodeChunk],
node,
source: bytes,
parent_chunk,
) -> list[CodeChunk]:
"""Detect MATLAB scripts and add script chunks when appropriate."""
# Only process at the source_file level (top level) with no parent chunk
if node.type != "source_file" or parent_chunk is not None:
return chunks
    # Check if there are top-level statements that make this a script
    has_top_level_code = False
# Look for top-level statements in the node children
for child in node.children:
if child.type in {"assignment", "function_call", "command", "comment"}:
has_top_level_code = True
break
# If there's top-level code, create a script chunk for the whole file
if has_top_level_code:
content = source.decode("utf-8", errors="replace")
script_chunk = CodeChunk(
language="matlab",
file_path="",
node_type="script",
start_line=1,
end_line=content.count("\n") + 1,
byte_start=0,
byte_end=len(source),
parent_context="",
content=content,
parent_chunk_id=None,
parent_route=["script"],
)
# Insert script chunk at the beginning
chunks.insert(0, script_chunk)
return chunks
def chunk_text(
text: str,
language: str,
file_path: str = "",
extract_metadata: bool = True,
include_retrieval_metadata: bool = False,
) -> list[CodeChunk]:
"""Parse text and return a list of `CodeChunk`.
Args:
text: Source code text to chunk
language: Programming language
file_path: Path to the file (optional)
extract_metadata: Whether to extract metadata (default: True)
include_retrieval_metadata: Whether to add retrieval-oriented metadata
Returns:
List of CodeChunk objects with optional metadata
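
    Example (a minimal sketch; assumes a tree-sitter Python grammar is
    installed so ``get_parser("python")`` succeeds):

    >>> chunks = chunk_text("def foo(): return 1", "python")
    >>> chunks[0].node_type
    'function_definition'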
"""
parser = get_parser(language)
src = text.encode()
tree = parser.parse(src)
# Create metadata extractors if requested
extractor = None
analyzer = None
if extract_metadata:
extractor = MetadataExtractorFactory.create_extractor(language)
analyzer = MetadataExtractorFactory.create_analyzer(language)
chunks = _walk(
tree.root_node,
src,
language,
extractor=extractor,
analyzer=analyzer,
)
# Build mapping from temporary IDs (no path) to final IDs (with path)
tmp_to_final: dict[str, str] = {}
for c in chunks:
tmp_id = compute_node_id("", c.language, c.parent_route, c.content)
final_id = compute_node_id(
file_path,
c.language,
c.parent_route,
c.content,
)
tmp_to_final[tmp_id] = final_id
for c in chunks:
c.file_path = file_path
# update file/node ids now that path is known
c.file_id = compute_file_id(file_path)
c.node_id = compute_node_id(
file_path,
c.language,
c.parent_route,
c.content,
)
c.chunk_id = c.node_id
# fix parent id if it was set using temporary id
if c.parent_chunk_id and c.parent_chunk_id in tmp_to_final:
c.parent_chunk_id = tmp_to_final[c.parent_chunk_id]
# recompute symbol id if signature present
sig = c.metadata.get("signature") if c.metadata else None
        if isinstance(sig, dict) and sig.get("name"):
c.symbol_id = compute_symbol_id(language, file_path, sig["name"])
if include_retrieval_metadata:
_apply_retrieval_metadata(chunks)
for c in chunks:
symbol = c.metadata.get("symbol") if c.metadata else None
if symbol:
c.symbol_id = compute_symbol_id(language, file_path, str(symbol))
return chunks
def chunk_file(
path: str | Path,
language: str,
extract_metadata: bool = True,
include_retrieval_metadata: bool = False,
) -> list[CodeChunk]:
"""Parse the file and return a list of `CodeChunk`.
Args:
path: Path to the file to chunk
language: Programming language
extract_metadata: Whether to extract metadata (default: True)
include_retrieval_metadata: Whether to add retrieval-oriented metadata
Returns:
List of CodeChunk objects with optional metadata
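
    Example (illustrative; the path is hypothetical)::

        chunks = chunk_file("src/app.py", "python")
        for c in chunks:
            print(c.node_type, c.start_line, c.end_line)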
"""
# Read file contents with robust decoding
p = Path(path)
try:
src = p.read_text(encoding="utf-8")
except UnicodeDecodeError:
# Fallback: replace invalid bytes to avoid crashing on bad encodings
src = p.read_bytes().decode("utf-8", errors="replace")
# Special handling for R Markdown: extract embedded R code blocks
if language == "r" and p.suffix.lower() in {".rmd", ".rmarkdown"}:
from .multi_language import ( # local import to avoid cycle
MultiLanguageProcessorImpl,
)
ml = MultiLanguageProcessorImpl()
# Prefer robust Rmd extraction that supports ```{r chunk-name}
pattern = re.compile(r"```\{r[^}]*\}\s*\r?\n([\s\S]*?)\r?\n```", re.DOTALL)
snippets = [(m.group(1), m.start(1), m.end(1)) for m in pattern.finditer(src)]
# Fallback to generic markdown extractor if custom pattern finds nothing
if not snippets:
snippets = ml.extract_embedded_code(
src,
host_language="markdown",
target_language="r",
)
all_chunks: list[CodeChunk] = []
for code, start, end in snippets:
# Derive pseudo file name for chunk id stability
pseudo_path = f"{p}:{start}-{end}"
all_chunks.extend(
chunk_text(
code,
"r",
pseudo_path,
extract_metadata=extract_metadata,
include_retrieval_metadata=include_retrieval_metadata,
),
)
return all_chunks
return chunk_text(
src,
language,
str(path),
extract_metadata=extract_metadata,
include_retrieval_metadata=include_retrieval_metadata,
)