Skip to content

Commit b17d4ff

Browse files
authored
Merge pull request #1093 from kngwyu/iterator-example
Improve lifetime insertions for #[pyproto]
2 parents c81013b + c4d9ab2 commit b17d4ff

File tree

7 files changed

+79
-66
lines changed

7 files changed

+79
-66
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
3939
[#1058](https://github.com/PyO3/pyo3/pull/1058). [#1059](https://github.com/PyO3/pyo3/pull/1059)
4040
- Allows `&Self` as a `#[pymethods]` argument again. [#1071](https://github.com/PyO3/pyo3/pull/1071)
4141
- Fix best-effort build against PyPy 3.6. #[1092](https://github.com/PyO3/pyo3/pull/1092)
42+
- Improve lifetime elision in `#[pyproto]`. [#1093](https://github.com/PyO3/pyo3/pull/1093)
4243

4344
## [0.11.1] - 2020-06-30
4445
### Added

guide/src/class.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -923,8 +923,8 @@ struct MyIterator {
923923

924924
#[pyproto]
925925
impl PyIterProtocol for MyIterator {
926-
fn __iter__(slf: PyRef<Self>) -> Py<MyIterator> {
927-
slf.into()
926+
fn __iter__(slf: PyRef<Self>) -> PyRef<Self> {
927+
slf
928928
}
929929
fn __next__(mut slf: PyRefMut<Self>) -> Option<PyObject> {
930930
slf.iter.next()
@@ -948,8 +948,8 @@ struct Iter {
948948

949949
#[pyproto]
950950
impl PyIterProtocol for Iter {
951-
fn __iter__(slf: PyRefMut<Self>) -> Py<Iter> {
952-
slf.into()
951+
fn __iter__(slf: PyRef<Self>) -> PyRef<Self> {
952+
slf
953953
}
954954

955955
fn __next__(mut slf: PyRefMut<Self>) -> Option<usize> {
@@ -964,7 +964,7 @@ struct Container {
964964

965965
#[pyproto]
966966
impl PyIterProtocol for Container {
967-
fn __iter__(slf: PyRefMut<Self>) -> PyResult<Py<Iter>> {
967+
fn __iter__(slf: PyRef<Self>) -> PyResult<Py<Iter>> {
968968
let iter = Iter {
969969
inner: slf.iter.clone().into_iter(),
970970
};

pyo3-derive-backend/src/defs.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// Copyright (c) 2017-present PyO3 Project and Contributors
2-
use crate::func::MethodProto;
2+
use crate::proto_method::MethodProto;
33

44
/// Predicates for `#[pyproto]`.
55
pub struct Proto {

pyo3-derive-backend/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
#![recursion_limit = "1024"]
55

66
mod defs;
7-
mod func;
87
mod konst;
98
mod method;
109
mod module;
10+
mod proto_method;
1111
mod pyclass;
1212
mod pyfunction;
1313
mod pyimpl;

pyo3-derive-backend/src/func.rs renamed to pyo3-derive-backend/src/proto_method.rs

Lines changed: 54 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ use syn::Token;
66

77
// TODO:
88
// Add lifetime support for args with Rptr
9-
109
#[derive(Debug)]
1110
pub enum MethodProto {
1211
Free {
@@ -77,7 +76,11 @@ pub(crate) fn impl_method_proto(
7776
) -> TokenStream {
7877
let ret_ty = match &sig.output {
7978
syn::ReturnType::Default => quote! { () },
80-
syn::ReturnType::Type(_, ty) => ty.to_token_stream(),
79+
syn::ReturnType::Type(_, ty) => {
80+
let mut ty = ty.clone();
81+
insert_lifetime(&mut ty);
82+
ty.to_token_stream()
83+
}
8184
};
8285

8386
match *meth {
@@ -106,22 +109,7 @@ pub(crate) fn impl_method_proto(
106109
let p: syn::Path = syn::parse_str(proto).unwrap();
107110

108111
let slf_name = syn::Ident::new(arg, Span::call_site());
109-
let mut slf_ty = get_arg_ty(sig, 0);
110-
111-
// update the type if no lifetime was given:
112-
// PyRef<Self> --> PyRef<'p, Self>
113-
if let syn::Type::Path(ref mut path) = slf_ty {
114-
if let syn::PathArguments::AngleBracketed(ref mut args) =
115-
path.path.segments[0].arguments
116-
{
117-
if let syn::GenericArgument::Lifetime(_) = args.args[0] {
118-
} else {
119-
let lt = syn::parse_quote! {'p};
120-
args.args.insert(0, lt);
121-
}
122-
}
123-
}
124-
112+
let slf_ty = get_arg_ty(sig, 0);
125113
let tmp: syn::ItemFn = syn::parse_quote! {
126114
fn test(&self) -> <#cls as #p<'p>>::Result {}
127115
};
@@ -336,38 +324,62 @@ pub(crate) fn impl_method_proto(
336324
}
337325
}
338326

339-
// TODO: better arg ty detection
327+
/// Some hacks for arguments: get `T` from `Option<T>` and insert lifetime
340328
fn get_arg_ty(sig: &syn::Signature, idx: usize) -> syn::Type {
341-
let mut ty = match sig.inputs[idx] {
342-
syn::FnArg::Typed(ref cap) => {
343-
match *cap.ty {
344-
syn::Type::Path(ref ty) => {
345-
// use only last path segment for Option<>
346-
let seg = ty.path.segments.last().unwrap().clone();
347-
if seg.ident == "Option" {
348-
if let syn::PathArguments::AngleBracketed(ref data) = seg.arguments {
349-
if let Some(pair) = data.args.last() {
350-
match pair {
351-
syn::GenericArgument::Type(ref ty) => return ty.clone(),
352-
_ => panic!("Option only accepted for concrete types"),
353-
}
354-
};
355-
}
356-
}
357-
*cap.ty.clone()
329+
fn get_option_ty(path: &syn::Path) -> Option<syn::Type> {
330+
let seg = path.segments.last()?;
331+
if seg.ident == "Option" {
332+
if let syn::PathArguments::AngleBracketed(ref data) = seg.arguments {
333+
if let Some(syn::GenericArgument::Type(ref ty)) = data.args.last() {
334+
return Some(ty.to_owned());
358335
}
359-
_ => *cap.ty.clone(),
360336
}
361337
}
362-
_ => panic!("fn arg type is not supported"),
338+
None
339+
}
340+
341+
let mut ty = match &sig.inputs[idx] {
342+
syn::FnArg::Typed(ref cap) => match &*cap.ty {
343+
// For `Option<T>`, we use `T` as an associated type for the protocol.
344+
syn::Type::Path(ref ty) => get_option_ty(&ty.path).unwrap_or_else(|| *cap.ty.clone()),
345+
_ => *cap.ty.clone(),
346+
},
347+
ty => panic!("Unsupported argument type: {:?}", ty),
363348
};
349+
insert_lifetime(&mut ty);
350+
ty
351+
}
364352

365-
// Add a lifetime if there is none
366-
if let syn::Type::Reference(ref mut r) = ty {
367-
r.lifetime.get_or_insert(syn::parse_quote! {'p});
353+
/// Insert lifetime `'p` to `PyRef<Self>` or references (e.g., `&PyType`).
354+
fn insert_lifetime(ty: &mut syn::Type) {
355+
fn insert_lifetime_for_path(path: &mut syn::TypePath) {
356+
if let Some(seg) = path.path.segments.last_mut() {
357+
if let syn::PathArguments::AngleBracketed(ref mut args) = seg.arguments {
358+
let mut has_lifetime = false;
359+
for arg in &mut args.args {
360+
match arg {
361+
// Insert `'p` recursively for `Option<PyRef<Self>>` or so.
362+
syn::GenericArgument::Type(ref mut ty) => insert_lifetime(ty),
363+
syn::GenericArgument::Lifetime(_) => has_lifetime = true,
364+
_ => {}
365+
}
366+
}
367+
// Insert lifetime to PyRef (i.e., PyRef<Self> -> PyRef<'p, Self>)
368+
if !has_lifetime && (seg.ident == "PyRef" || seg.ident == "PyRefMut") {
369+
args.args.insert(0, syn::parse_quote! {'p});
370+
}
371+
}
372+
}
368373
}
369374

370-
ty
375+
match ty {
376+
syn::Type::Reference(ref mut r) => {
377+
r.lifetime.get_or_insert(syn::parse_quote! {'p});
378+
insert_lifetime(&mut *r.elem);
379+
}
380+
syn::Type::Path(ref mut path) => insert_lifetime_for_path(path),
381+
_ => {}
382+
}
371383
}
372384

373385
fn extract_decl(spec: syn::Item) -> syn::Signature {

pyo3-derive-backend/src/pyproto.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
// Copyright (c) 2017-present PyO3 Project and Contributors
22

33
use crate::defs;
4-
use crate::func::impl_method_proto;
54
use crate::method::{FnSpec, FnType};
5+
use crate::proto_method::impl_method_proto;
66
use crate::pymethod;
77
use proc_macro2::{Span, TokenStream};
88
use quote::quote;

tests/test_dunder.rs

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,12 @@ struct Iterator {
5151
}
5252

5353
#[pyproto]
54-
impl<'p> PyIterProtocol for Iterator {
55-
fn __iter__(slf: PyRef<'p, Self>) -> Py<Iterator> {
56-
slf.into()
54+
impl PyIterProtocol for Iterator {
55+
fn __iter__(slf: PyRef<Self>) -> PyRef<Self> {
56+
slf
5757
}
5858

59-
fn __next__(mut slf: PyRefMut<'p, Self>) -> Option<i32> {
59+
fn __next__(mut slf: PyRefMut<Self>) -> Option<i32> {
6060
slf.iter.next()
6161
}
6262
}
@@ -81,7 +81,7 @@ fn iterator() {
8181
struct StringMethods {}
8282

8383
#[pyproto]
84-
impl<'p> PyObjectProtocol<'p> for StringMethods {
84+
impl PyObjectProtocol for StringMethods {
8585
fn __str__(&self) -> &'static str {
8686
"str"
8787
}
@@ -236,7 +236,7 @@ struct SetItem {
236236
}
237237

238238
#[pyproto]
239-
impl PyMappingProtocol<'a> for SetItem {
239+
impl PyMappingProtocol for SetItem {
240240
fn __setitem__(&mut self, key: i32, val: i32) {
241241
self.key = key;
242242
self.val = val;
@@ -362,16 +362,16 @@ struct ContextManager {
362362
}
363363

364364
#[pyproto]
365-
impl<'p> PyContextProtocol<'p> for ContextManager {
365+
impl PyContextProtocol for ContextManager {
366366
fn __enter__(&mut self) -> i32 {
367367
42
368368
}
369369

370370
fn __exit__(
371371
&mut self,
372-
ty: Option<&'p PyType>,
373-
_value: Option<&'p PyAny>,
374-
_traceback: Option<&'p PyAny>,
372+
ty: Option<&PyType>,
373+
_value: Option<&PyAny>,
374+
_traceback: Option<&PyAny>,
375375
) -> bool {
376376
let gil = Python::acquire_gil();
377377
self.exit_called = true;
@@ -564,14 +564,14 @@ impl OnceFuture {
564564

565565
#[pyproto]
566566
impl PyAsyncProtocol for OnceFuture {
567-
fn __await__(slf: PyRef<'p, Self>) -> PyRef<'p, Self> {
567+
fn __await__(slf: PyRef<Self>) -> PyRef<Self> {
568568
slf
569569
}
570570
}
571571

572572
#[pyproto]
573573
impl PyIterProtocol for OnceFuture {
574-
fn __iter__(slf: PyRef<'p, Self>) -> PyRef<'p, Self> {
574+
fn __iter__(slf: PyRef<Self>) -> PyRef<Self> {
575575
slf
576576
}
577577
fn __next__(mut slf: PyRefMut<Self>) -> Option<PyObject> {
@@ -632,14 +632,14 @@ impl DescrCounter {
632632
#[pyproto]
633633
impl PyDescrProtocol for DescrCounter {
634634
fn __get__(
635-
mut slf: PyRefMut<'p, Self>,
635+
mut slf: PyRefMut<Self>,
636636
_instance: &PyAny,
637-
_owner: Option<&'p PyType>,
638-
) -> PyRefMut<'p, Self> {
637+
_owner: Option<&PyType>,
638+
) -> PyRefMut<Self> {
639639
slf.count += 1;
640640
slf
641641
}
642-
fn __set__(_slf: PyRef<'p, Self>, _instance: &PyAny, mut new_value: PyRefMut<'p, Self>) {
642+
fn __set__(_slf: PyRef<Self>, _instance: &PyAny, mut new_value: PyRefMut<Self>) {
643643
new_value.count = _slf.count;
644644
}
645645
}

0 commit comments

Comments
 (0)