diff --git a/libcst/__init__.py b/libcst/__init__.py index ff63033d5..8a51befb2 100644 --- a/libcst/__init__.py +++ b/libcst/__init__.py @@ -183,6 +183,7 @@ MatchValue, NameItem, Nonlocal, + ParamSpec, Pass, Raise, Return, @@ -190,6 +191,11 @@ SimpleStatementSuite, Try, TryStar, + TypeAlias, + TypeParameters, + TypeParam, + TypeVar, + TypeVarTuple, While, With, WithItem, @@ -438,4 +444,10 @@ "VisitorMetadataProvider", "MetadataDependent", "MetadataWrapper", + "TypeVar", + "TypeVarTuple", + "ParamSpec", + "TypeParam", + "TypeParameters", + "TypeAlias", ] diff --git a/libcst/_nodes/statement.py b/libcst/_nodes/statement.py index de5161fa6..4b363cf7e 100644 --- a/libcst/_nodes/statement.py +++ b/libcst/_nodes/statement.py @@ -1713,6 +1713,9 @@ class FunctionDef(BaseCompoundStatement): #: The function body. body: BaseSuite + #: An optional declaration of type parameters. + type_parameters: Optional[TypeParameters] = None + #: Sequence of decorators applied to this function. Decorators are listed in #: order that they appear in source (top to bottom) as apposed to the order #: that they are applied to the function at runtime. @@ -1738,10 +1741,14 @@ class FunctionDef(BaseCompoundStatement): #: Whitespace after the ``def`` keyword and before the function name. whitespace_after_def: SimpleWhitespace = SimpleWhitespace.field(" ") - #: Whitespace after the function name and before the opening parenthesis for - #: the parameters. + #: Whitespace after the function name and before the type parameters or the opening + #: parenthesis for the parameters. whitespace_after_name: SimpleWhitespace = SimpleWhitespace.field("") + #: Whitespace between the type parameters and the opening parenthesis for the + #: (non-type) parameters. + whitespace_after_type_parameters: SimpleWhitespace = SimpleWhitespace.field("") + #: Whitespace after the opening parenthesis for the parameters but before #: the first param itself. whitespace_before_params: BaseParenthesizableWhitespace = SimpleWhitespace.field("") @@ -1758,6 +1765,15 @@ def _validate(self) -> None: "There must be at least one space between 'def' and name." ) + if ( + self.type_parameters is None + and self.whitespace_after_type_parameters != SimpleWhitespace("") + ): + raise CSTValidationError( + "whitespace_after_type_parameters must be empty if there are no type " + "parameters in FunctionDef" + ) + def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "FunctionDef": return FunctionDef( leading_lines=visit_sequence( @@ -1777,6 +1793,15 @@ def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "FunctionDef": whitespace_after_name=visit_required( self, "whitespace_after_name", self.whitespace_after_name, visitor ), + type_parameters=visit_optional( + self, "type_parameters", self.type_parameters, visitor + ), + whitespace_after_type_parameters=visit_required( + self, + "whitespace_after_type_parameters", + self.whitespace_after_type_parameters, + visitor, + ), whitespace_before_params=visit_required( self, "whitespace_before_params", self.whitespace_before_params, visitor ), @@ -1805,6 +1830,10 @@ def _codegen_impl(self, state: CodegenState) -> None: self.whitespace_after_def._codegen(state) self.name._codegen(state) self.whitespace_after_name._codegen(state) + type_params = self.type_parameters + if type_params is not None: + type_params._codegen(state) + self.whitespace_after_type_parameters._codegen(state) state.add_token("(") self.whitespace_before_params._codegen(state) self.params._codegen(state) @@ -1837,6 +1866,9 @@ class ClassDef(BaseCompoundStatement): #: The class body. body: BaseSuite + #: An optional declaration of type parameters. + type_parameters: Optional[TypeParameters] = None + #: Sequence of base classes this class inherits from. bases: Sequence[Arg] = () @@ -1866,10 +1898,14 @@ class ClassDef(BaseCompoundStatement): #: Whitespace after the ``class`` keyword and before the class name. whitespace_after_class: SimpleWhitespace = SimpleWhitespace.field(" ") - #: Whitespace after the class name and before the opening parenthesis for - #: the bases and keywords. + #: Whitespace after the class name and before the type parameters or the opening + #: parenthesis for the bases and keywords. whitespace_after_name: SimpleWhitespace = SimpleWhitespace.field("") + #: Whitespace between type parameters and opening parenthesis for the bases and + #: keywords. + whitespace_after_type_parameters: SimpleWhitespace = SimpleWhitespace.field("") + #: Whitespace after the closing parenthesis or class name and before #: the colon. whitespace_before_colon: SimpleWhitespace = SimpleWhitespace.field("") @@ -1879,6 +1915,14 @@ def _validate_whitespace(self) -> None: raise CSTValidationError( "There must be at least one space between 'class' and name." ) + if ( + self.type_parameters is None + and self.whitespace_after_type_parameters != SimpleWhitespace("") + ): + raise CSTValidationError( + "whitespace_after_type_parameters must be empty if there are no type" + "parameters in a ClassDef" + ) def _validate_parens(self) -> None: if len(self.name.lpar) > 0 or len(self.name.rpar) > 0: @@ -1921,6 +1965,15 @@ def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "ClassDef": whitespace_after_name=visit_required( self, "whitespace_after_name", self.whitespace_after_name, visitor ), + type_parameters=visit_optional( + self, "type_parameters", self.type_parameters, visitor + ), + whitespace_after_type_parameters=visit_required( + self, + "whitespace_after_type_parameters", + self.whitespace_after_type_parameters, + visitor, + ), lpar=visit_sentinel(self, "lpar", self.lpar, visitor), bases=visit_sequence(self, "bases", self.bases, visitor), keywords=visit_sequence(self, "keywords", self.keywords, visitor), @@ -1945,6 +1998,10 @@ def _codegen_impl(self, state: CodegenState) -> None: # noqa: C901 self.whitespace_after_class._codegen(state) self.name._codegen(state) self.whitespace_after_name._codegen(state) + type_params = self.type_parameters + if type_params is not None: + type_params._codegen(state) + self.whitespace_after_type_parameters._codegen(state) lpar = self.lpar if isinstance(lpar, MaybeSentinel): if self.bases or self.keywords: @@ -3476,3 +3533,283 @@ def _codegen_impl(self, state: CodegenState) -> None: pats = self.patterns for idx, pat in enumerate(pats): pat._codegen(state, default_separator=idx + 1 < len(pats)) + + +@add_slots +@dataclass(frozen=True) +class TypeVar(CSTNode): + """ + A simple (non-variadic) type variable. + + Note: this node represents type a variable when declared using PEP-695 syntax. + """ + + #: The name of the type variable. + name: Name + + #: An optional bound on the type. + bound: Optional[BaseExpression] = None + + #: The colon used to separate the name and bound. If not specified, + #: :class:`MaybeSentinel` will be replaced with a colon if there is a bound, + #: otherwise will be left empty. + colon: Union[Colon, MaybeSentinel] = MaybeSentinel.DEFAULT + + def _codegen_impl(self, state: CodegenState) -> None: + with state.record_syntactic_position(self): + self.name._codegen(state) + bound = self.bound + colon = self.colon + if not isinstance(colon, MaybeSentinel): + colon._codegen(state) + else: + if bound is not None: + state.add_token(": ") + + if bound is not None: + bound._codegen(state) + + def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "TypeVar": + return TypeVar( + name=visit_required(self, "name", self.name, visitor), + colon=visit_sentinel(self, "colon", self.colon, visitor), + bound=visit_optional(self, "bound", self.bound, visitor), + ) + + +@add_slots +@dataclass(frozen=True) +class TypeVarTuple(CSTNode): + """ + A variadic type variable. + """ + + #: The name of this type variable. + name: Name + + #: The (optional) whitespace between the star declaring this type variable as + #: variadic, and the variable's name. + whitespace_after_star: SimpleWhitespace = SimpleWhitespace.field(" ") + + def _codegen_impl(self, state: CodegenState) -> None: + with state.record_syntactic_position(self): + state.add_token("*") + self.whitespace_after_star._codegen(state) + self.name._codegen(state) + + def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "TypeVarTuple": + return TypeVarTuple( + name=visit_required(self, "name", self.name, visitor), + whitespace_after_star=visit_required( + self, "whitespace_after_star", self.whitespace_after_star, visitor + ), + ) + + +@add_slots +@dataclass(frozen=True) +class ParamSpec(CSTNode): + """ + A parameter specification. + + Note: this node represents a parameter specification when declared using PEP-695 + syntax. + """ + + #: The name of this parameter specification. + name: Name + + #: The (optional) whitespace between the double star declaring this type variable as + #: a parameter specification, and the name. + whitespace_after_star: SimpleWhitespace = SimpleWhitespace.field(" ") + + def _codegen_impl(self, state: CodegenState) -> None: + with state.record_syntactic_position(self): + state.add_token("**") + self.whitespace_after_star._codegen(state) + self.name._codegen(state) + + def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "ParamSpec": + return ParamSpec( + name=visit_required(self, "name", self.name, visitor), + whitespace_after_star=visit_required( + self, "whitespace_after_star", self.whitespace_after_star, visitor + ), + ) + + +@add_slots +@dataclass(frozen=True) +class TypeParam(CSTNode): + """ + A single type parameter that is contained in a :class:`TypeParameters` list. + """ + + #: The actual parameter. + param: Union[TypeVar, TypeVarTuple, ParamSpec] + + #: A trailing comma. If one is not provided, :class:`MaybeSentinel` will be replaced + #: with a comma only if a comma is required. + comma: Union[Comma, MaybeSentinel] = MaybeSentinel.DEFAULT + + def _codegen_impl(self, state: CodegenState, default_comma: bool = False) -> None: + self.param._codegen(state) + comma = self.comma + if isinstance(comma, MaybeSentinel): + if default_comma: + state.add_token(", ") + else: + comma._codegen(state) + + def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "TypeParam": + return TypeParam( + param=visit_required(self, "param", self.param, visitor), + comma=visit_sentinel(self, "comma", self.comma, visitor), + ) + + +@add_slots +@dataclass(frozen=True) +class TypeParameters(CSTNode): + """ + Type parameters when specified with PEP-695 syntax. + + This node captures all specified parameters that are enclosed with square brackets. + """ + + #: The parameters within the square brackets. + params: Sequence[TypeParam] = () + + #: Opening square bracket that marks the start of these parameters. + lbracket: LeftSquareBracket = LeftSquareBracket.field() + #: Closing square bracket that marks the end of these parameters. + rbracket: RightSquareBracket = RightSquareBracket.field() + + def _codegen_impl(self, state: CodegenState) -> None: + self.lbracket._codegen(state) + params_len = len(self.params) + for idx, param in enumerate(self.params): + param._codegen(state, default_comma=idx + 1 < params_len) + self.rbracket._codegen(state) + + def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "TypeParameters": + return TypeParameters( + lbracket=visit_required(self, "lbracket", self.lbracket, visitor), + params=visit_sequence(self, "params", self.params, visitor), + rbracket=visit_required(self, "rbracket", self.rbracket, visitor), + ) + + +@add_slots +@dataclass(frozen=True) +class TypeAlias(BaseSmallStatement): + """ + A type alias statement. + + This node represents the ``type`` statement as specified initially by PEP-695. + Example: ``type ListOrSet[T] = list[T] | set[T]``. + """ + + #: The name being introduced in this statement. + name: Name + + #: Everything on the right hand side of the ``=``. + value: BaseExpression + + #: An optional list of type parameters, specified after the name. + type_parameters: Optional[TypeParameters] = None + + #: Whitespace between the ``type`` soft keyword and the name. + whitespace_after_type: SimpleWhitespace = SimpleWhitespace.field(" ") + + #: Whitespace between the name and the type parameters (if they exist) or the ``=``. + #: If not specified, :class:`MaybeSentinel` will be replaced with a single space if + #: there are no type parameters, otherwise no spaces. + whitespace_after_name: Union[ + SimpleWhitespace, MaybeSentinel + ] = MaybeSentinel.DEFAULT + + #: Whitespace between the type parameters and the ``=``. Always empty if there are + #: no type parameters. If not specified, :class:`MaybeSentinel` will be replaced + #: with a single space if there are type parameters. + whitespace_after_type_parameters: Union[ + SimpleWhitespace, MaybeSentinel + ] = MaybeSentinel.DEFAULT + + #: Whitespace between the ``=`` and the value. + whitespace_after_equals: SimpleWhitespace = SimpleWhitespace.field(" ") + + #: Optional semicolon when this is used in a statement line. This semicolon + #: owns the whitespace on both sides of it when it is used. + semicolon: Union[Semicolon, MaybeSentinel] = MaybeSentinel.DEFAULT + + def _validate(self) -> None: + if ( + self.type_parameters is None + and self.whitespace_after_type_parameters + not in { + SimpleWhitespace(""), + MaybeSentinel.DEFAULT, + } + ): + raise CSTValidationError( + "whitespace_after_type_parameters must be empty when there are no type parameters in a TypeAlias" + ) + + def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "TypeAlias": + return TypeAlias( + whitespace_after_type=visit_required( + self, "whitespace_after_type", self.whitespace_after_type, visitor + ), + name=visit_required(self, "name", self.name, visitor), + whitespace_after_name=visit_sentinel( + self, "whitespace_after_name", self.whitespace_after_name, visitor + ), + type_parameters=visit_optional( + self, "type_parameters", self.type_parameters, visitor + ), + whitespace_after_type_parameters=visit_sentinel( + self, + "whitespace_after_type_parameters", + self.whitespace_after_type_parameters, + visitor, + ), + whitespace_after_equals=visit_required( + self, "whitespace_after_equals", self.whitespace_after_equals, visitor + ), + value=visit_required(self, "value", self.value, visitor), + semicolon=visit_sentinel(self, "semicolon", self.semicolon, visitor), + ) + + def _codegen_impl( + self, state: CodegenState, default_semicolon: bool = False + ) -> None: + with state.record_syntactic_position(self): + state.add_token("type") + self.whitespace_after_type._codegen(state) + self.name._codegen(state) + ws_after_name = self.whitespace_after_name + if isinstance(ws_after_name, MaybeSentinel): + if self.type_parameters is None: + state.add_token(" ") + else: + ws_after_name._codegen(state) + + ws_after_type_params = self.whitespace_after_type_parameters + if self.type_parameters is not None: + self.type_parameters._codegen(state) + if isinstance(ws_after_type_params, MaybeSentinel): + state.add_token(" ") + else: + ws_after_type_params._codegen(state) + + state.add_token("=") + self.whitespace_after_equals._codegen(state) + self.value._codegen(state) + + semi = self.semicolon + if isinstance(semi, MaybeSentinel): + if default_semicolon: + state.add_token("; ") + else: + semi._codegen(state) diff --git a/native/libcst/src/nodes/statement.rs b/native/libcst/src/nodes/statement.rs index 65006ab36..4c52aea00 100644 --- a/native/libcst/src/nodes/statement.rs +++ b/native/libcst/src/nodes/statement.rs @@ -304,6 +304,7 @@ pub enum SmallStatement<'a> { Nonlocal(Nonlocal<'a>), AugAssign(AugAssign<'a>), Del(Del<'a>), + TypeAlias(TypeAlias<'a>), } impl<'r, 'a> DeflatedSmallStatement<'r, 'a> { @@ -324,6 +325,7 @@ impl<'r, 'a> DeflatedSmallStatement<'r, 'a> { Self::Nonlocal(l) => Self::Nonlocal(l.with_semicolon(semicolon)), Self::AugAssign(a) => Self::AugAssign(a.with_semicolon(semicolon)), Self::Del(d) => Self::Del(d.with_semicolon(semicolon)), + Self::TypeAlias(t) => Self::TypeAlias(t.with_semicolon(semicolon)), } } } @@ -3332,3 +3334,235 @@ impl<'r, 'a> Inflate<'a> for DeflatedMatchOr<'r, 'a> { }) } } + +#[cst_node] +pub struct TypeVar<'a> { + pub name: Name<'a>, + pub bound: Option>>, + pub colon: Option>, +} + +impl<'a> Codegen<'a> for TypeVar<'a> { + fn codegen(&self, state: &mut CodegenState<'a>) { + self.name.codegen(state); + self.colon.codegen(state); + if let Some(bound) = &self.bound { + bound.codegen(state); + } + } +} + +impl<'r, 'a> Inflate<'a> for DeflatedTypeVar<'r, 'a> { + type Inflated = TypeVar<'a>; + fn inflate(self, config: &Config<'a>) -> Result { + let name = self.name.inflate(config)?; + let colon = self.colon.inflate(config)?; + let bound = self.bound.inflate(config)?; + Ok(Self::Inflated { name, bound, colon }) + } +} + +#[cst_node] +pub struct TypeVarTuple<'a> { + pub name: Name<'a>, + + pub whitespace_after_star: SimpleWhitespace<'a>, + + pub(crate) star_tok: TokenRef<'a>, +} + +impl<'a> Codegen<'a> for TypeVarTuple<'a> { + fn codegen(&self, state: &mut CodegenState<'a>) { + state.add_token("*"); + self.whitespace_after_star.codegen(state); + self.name.codegen(state); + } +} + +impl<'r, 'a> Inflate<'a> for DeflatedTypeVarTuple<'r, 'a> { + type Inflated = TypeVarTuple<'a>; + fn inflate(self, config: &Config<'a>) -> Result { + let whitespace_after_star = + parse_simple_whitespace(config, &mut self.star_tok.whitespace_after.borrow_mut())?; + let name = self.name.inflate(config)?; + Ok(Self::Inflated { + name, + whitespace_after_star, + }) + } +} + +#[cst_node] +pub struct ParamSpec<'a> { + pub name: Name<'a>, + + pub whitespace_after_star: SimpleWhitespace<'a>, + + pub(crate) star_tok: TokenRef<'a>, +} + +impl<'a> Codegen<'a> for ParamSpec<'a> { + fn codegen(&self, state: &mut CodegenState<'a>) { + state.add_token("**"); + self.whitespace_after_star.codegen(state); + self.name.codegen(state); + } +} + +impl<'r, 'a> Inflate<'a> for DeflatedParamSpec<'r, 'a> { + type Inflated = ParamSpec<'a>; + fn inflate(self, config: &Config<'a>) -> Result { + let whitespace_after_star = + parse_simple_whitespace(config, &mut self.star_tok.whitespace_after.borrow_mut())?; + let name = self.name.inflate(config)?; + Ok(Self::Inflated { + name, + whitespace_after_star, + }) + } +} + +#[cst_node(Inflate, Codegen)] +pub enum TypeParamItem<'a> { + TypeVar(TypeVar<'a>), + TypeVarTuple(TypeVarTuple<'a>), + ParamSpec(ParamSpec<'a>), +} + +#[cst_node] +pub struct TypeParam<'a> { + pub param: TypeParamItem<'a>, + pub comma: Option>, +} + +impl<'a> Codegen<'a> for TypeParam<'a> { + fn codegen(&self, state: &mut CodegenState<'a>) { + self.param.codegen(state); + } +} + +impl<'r, 'a> Inflate<'a> for DeflatedTypeParam<'r, 'a> { + type Inflated = TypeParam<'a>; + fn inflate(self, config: &Config<'a>) -> Result { + let param = self.param.inflate(config)?; + let comma = self.comma.inflate(config)?; + Ok(Self::Inflated { param, comma }) + } +} + +#[cst_node] +pub struct TypeParameters<'a> { + pub params: Vec>, + + pub lbracket: LeftSquareBracket<'a>, + pub rbracket: RightSquareBracket<'a>, +} + +impl<'a> Codegen<'a> for TypeParameters<'a> { + fn codegen(&self, state: &mut CodegenState<'a>) { + self.lbracket.codegen(state); + let params_len = self.params.len(); + for (idx, param) in self.params.iter().enumerate() { + param.codegen(state); + if idx + 1 < params_len && param.comma.is_none() { + state.add_token(", "); + } + } + self.rbracket.codegen(state); + } +} + +impl<'r, 'a> Inflate<'a> for DeflatedTypeParameters<'r, 'a> { + type Inflated = TypeParameters<'a>; + fn inflate(self, config: &Config<'a>) -> Result { + let lbracket = self.lbracket.inflate(config)?; + let params = self.params.inflate(config)?; + let rbracket = self.rbracket.inflate(config)?; + Ok(Self::Inflated { + params, + lbracket, + rbracket, + }) + } +} + +#[cst_node] +pub struct TypeAlias<'a> { + pub name: Name<'a>, + pub value: Box>, + pub type_parameters: Option>, + + pub whitespace_after_type: SimpleWhitespace<'a>, + pub whitespace_after_name: Option>, + pub whitespace_after_type_parameters: Option>, + pub whitespace_after_equals: SimpleWhitespace<'a>, + pub semicolon: Option>, + + pub(crate) type_tok: TokenRef<'a>, + pub(crate) lbracket_tok: Option>, + pub(crate) equals_tok: TokenRef<'a>, + pub(crate) semicolon_tok: TokenRef<'a>, +} + +impl<'a> Codegen<'a> for TypeAlias<'a> { + fn codegen(&self, state: &mut CodegenState<'a>) { + state.add_token("type"); + self.whitespace_after_type.codegen(state); + self.name.codegen(state); + if self.whitespace_after_name.is_none() && self.type_parameters.is_none() { + state.add_token(" "); + } else { + self.whitespace_after_name.codegen(state); + } + if self.type_parameters.is_some() { + self.type_parameters.codegen(state); + self.whitespace_after_type_parameters.codegen(state); + } + state.add_token("="); + self.whitespace_after_equals.codegen(state); + self.semicolon.codegen(state); + } +} + +impl<'r, 'a> Inflate<'a> for DeflatedTypeAlias<'r, 'a> { + type Inflated = TypeAlias<'a>; + fn inflate(self, config: &Config<'a>) -> Result { + let whitespace_after_type = + parse_simple_whitespace(config, &mut self.type_tok.whitespace_after.borrow_mut())?; + let name = self.name.inflate(config)?; + let whitespace_after_name = Some(if let Some(tok) = self.lbracket_tok { + parse_simple_whitespace(config, &mut tok.whitespace_before.borrow_mut()) + } else { + parse_simple_whitespace(config, &mut self.equals_tok.whitespace_before.borrow_mut()) + }?); + let type_parameters = self.type_parameters.inflate(config)?; + let whitespace_after_type_parameters = if type_parameters.is_some() { + Some(parse_simple_whitespace( + config, + &mut self.equals_tok.whitespace_before.borrow_mut(), + )?) + } else { + None + }; + let whitespace_after_equals = + parse_simple_whitespace(config, &mut self.equals_tok.whitespace_after.borrow_mut())?; + let value = self.value.inflate(config)?; + let semicolon = self.semicolon.inflate(config)?; + Ok(Self::Inflated { + name, + value, + type_parameters, + whitespace_after_type, + whitespace_after_name, + whitespace_after_type_parameters, + whitespace_after_equals, + semicolon, + }) + } +} + +impl<'r, 'a> DeflatedTypeAlias<'r, 'a> { + pub fn with_semicolon(self, semicolon: Option>) -> Self { + Self { semicolon, ..self } + } +}