Skip to content

Commit

Permalink
Fix import loop with Pydantic (#624)
Browse files Browse the repository at this point in the history
  • Loading branch information
AdrienVannson authored Oct 15, 2024
1 parent c2bcd31 commit 6a3bbe3
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 20 deletions.
8 changes: 4 additions & 4 deletions src/betterproto/plugin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ class OutputTemplate:
parent_request: PluginRequestCompiler
package_proto_obj: FileDescriptorProto
input_files: List[str] = field(default_factory=list)
imports: Set[str] = field(default_factory=set)
imports_end: Set[str] = field(default_factory=set)
datetime_imports: Set[str] = field(default_factory=set)
pydantic_imports: Set[str] = field(default_factory=set)
builtins_import: bool = False
Expand Down Expand Up @@ -532,7 +532,7 @@ def py_type(self) -> str:
# Type referencing another defined Message or a named enum
return get_type_reference(
package=self.output_file.package,
imports=self.output_file.imports,
imports=self.output_file.imports_end,
source_type=self.proto_obj.type_name,
typing_compiler=self.typing_compiler,
pydantic=self.output_file.pydantic_dataclasses,
Expand Down Expand Up @@ -730,7 +730,7 @@ def py_input_message_type(self) -> str:
"""
return get_type_reference(
package=self.output_file.package,
imports=self.output_file.imports,
imports=self.output_file.imports_end,
source_type=self.proto_obj.input_type,
typing_compiler=self.output_file.typing_compiler,
unwrap=False,
Expand Down Expand Up @@ -760,7 +760,7 @@ def py_output_message_type(self) -> str:
"""
return get_type_reference(
package=self.output_file.package,
imports=self.output_file.imports,
imports=self.output_file.imports_end,
source_type=self.proto_obj.output_type,
typing_compiler=self.output_file.typing_compiler,
unwrap=False,
Expand Down
5 changes: 0 additions & 5 deletions src/betterproto/templates/header.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import {{ i }}

{% if output_file.pydantic_dataclasses %}
from pydantic.dataclasses import dataclass
from pydantic.dataclasses import rebuild_dataclass
{%- else -%}
from dataclasses import dataclass
{% endif %}
Expand All @@ -35,10 +34,6 @@ from betterproto.grpc.grpclib_server import ServiceBase
import grpclib
{% endif %}

{% for i in output_file.imports|sort %}
{{ i }}
{% endfor %}

{% if output_file.imports_type_checking_only %}
from typing import TYPE_CHECKING

Expand Down
16 changes: 6 additions & 10 deletions src/betterproto/templates/template.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,14 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"
{%- else -%}
{# Client streaming: need a request iterator instead #}
, {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.union(output_file.typing_compiler.async_iterable(method.py_input_message_type), output_file.typing_compiler.iterable(method.py_input_message_type)) }}
, {{ method.py_input_message_param }}_iterator: "{{ output_file.typing_compiler.union(output_file.typing_compiler.async_iterable(method.py_input_message_type), output_file.typing_compiler.iterable(method.py_input_message_type)) }}"
{%- endif -%}
,
*
, timeout: {{ output_file.typing_compiler.optional("float") }} = None
, deadline: {{ output_file.typing_compiler.optional('"Deadline"') }} = None
, metadata: {{ output_file.typing_compiler.optional('"MetadataLike"') }} = None
) -> {% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type ) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}:
) -> "{% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type ) }}{% else %}{{ method.py_output_message_type }}{% endif %}":
{% if method.comment %}
{{ method.comment }}

Expand Down Expand Up @@ -143,6 +143,10 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% endfor %}
{% endfor %}

{% for i in output_file.imports_end %}
{{ i }}
{% endfor %}

{% for service in output_file.services %}
class {{ service.py_name }}Base(ServiceBase):
{% if service.comment %}
Expand Down Expand Up @@ -211,11 +215,3 @@ class {{ service.py_name }}Base(ServiceBase):
}

{% endfor %}

{% if output_file.pydantic_dataclasses %}
{% for message in output_file.messages %}
{% if message.has_message_field %}
rebuild_dataclass({{ message.py_name }}) # type: ignore
{% endif %}
{% endfor %}
{% endif %}
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@ import "other.proto";
// (root: Test & RootPackageMessage) <-------> (other: OtherPackageMessage)
message Test {
RootPackageMessage message = 1;
other.OtherPackageMessage other = 2;
other.OtherPackageMessage other_value = 2;
}

0 comments on commit 6a3bbe3

Please sign in to comment.