Commit b1db8d84 authored by Georg Schmid's avatar Georg Schmid Committed by Commit Bot

[torque] Infer type arguments of generic struct initializers

Previously when creating a new generic struct, one had to explicitly provide all type arguments, e.g., for the generic struct

  struct Box<T: type> {
    const value: T;
  }

one would initialize a new box using

  const aSmi: Smi = ...;
  const box = Box<Smi> { value: aSmi };

With the additions in this CL the explicit type argument can be omitted. Type inference proceeds analogously to specialization of generic callables.

Additionally, this CL slightly refactors class and struct initialization, and make type inference more permissive in the presence of unsupported type constructors (concretely, union types and function types).

R=jgruber@chromium.org, tebbi@chromium.org

Change-Id: I529be5831a85d317d8caa6cb3a0ce398ad578c86
Bug: v8:7793
Reviewed-on: https://chromium-review.googlesource.com/c/v8/v8/+/1728617
Commit-Queue: Georg Schmid <gsps@google.com>
Reviewed-by: 's avatarTobias Tebbi <tebbi@chromium.org>
Cr-Commit-Position: refs/heads/master@{#63036}
parent f2f0562f
......@@ -139,6 +139,13 @@ GenericStructType* Declarations::LookupUniqueGenericStructType(
"generic struct");
}
base::Optional<GenericStructType*> Declarations::TryLookupGenericStructType(
const QualifiedName& name) {
std::vector<GenericStructType*> results = TryLookup<GenericStructType>(name);
if (results.empty()) return base::nullopt;
return EnsureUnique(results, name.name, "generic struct");
}
Namespace* Declarations::DeclareNamespace(const std::string& name) {
return Declare(name, std::unique_ptr<Namespace>(new Namespace(name)));
}
......
......@@ -78,6 +78,8 @@ class Declarations {
static GenericStructType* LookupUniqueGenericStructType(
const QualifiedName& name);
static base::Optional<GenericStructType*> TryLookupGenericStructType(
const QualifiedName& name);
static Namespace* DeclareNamespace(const std::string& name);
static TypeAlias* DeclareType(const Identifier* name, const Type* type);
......
......@@ -1111,29 +1111,6 @@ const Type* ImplementationVisitor::Visit(ReturnStatement* stmt) {
return TypeOracle::GetNeverType();
}
VisitResult ImplementationVisitor::TemporaryUninitializedStruct(
const StructType* struct_type, const std::string& reason) {
StackRange range = assembler().TopRange(0);
for (const Field& f : struct_type->fields()) {
if (const StructType* struct_type =
StructType::DynamicCast(f.name_and_type.type)) {
range.Extend(
TemporaryUninitializedStruct(struct_type, reason).stack_range());
} else {
std::string descriptor = "uninitialized field '" + f.name_and_type.name +
"' declared at " + PositionAsString(f.pos) +
" (" + reason + ")";
TypeVector lowered_types = LowerType(f.name_and_type.type);
for (const Type* type : lowered_types) {
assembler().Emit(PushUninitializedInstruction{
TypeOracle::GetTopType(descriptor, type)});
}
range.Extend(assembler().TopRange(lowered_types.size()));
}
}
return VisitResult(struct_type, range);
}
VisitResult ImplementationVisitor::Visit(TryLabelExpression* expr) {
size_t parameter_count = expr->label_block->parameters.names.size();
std::vector<VisitResult> parameters;
......@@ -1212,15 +1189,38 @@ VisitResult ImplementationVisitor::Visit(StatementExpression* expr) {
return VisitResult{Visit(expr->statement), assembler().TopRange(0)};
}
void ImplementationVisitor::CheckInitializersWellformed(
const std::string& aggregate_name,
const std::vector<Field>& aggregate_fields,
const std::vector<NameAndExpression>& initializers,
bool ignore_first_field) {
size_t fields_offset = ignore_first_field ? 1 : 0;
size_t fields_size = aggregate_fields.size() - fields_offset;
for (size_t i = 0; i < std::min(fields_size, initializers.size()); i++) {
const std::string& field_name =
aggregate_fields[i + fields_offset].name_and_type.name;
Identifier* found_name = initializers[i].name;
if (field_name != found_name->value) {
Error("Expected field name \"", field_name, "\" instead of \"",
found_name->value, "\"")
.Position(found_name->pos)
.Throw();
}
}
if (fields_size != initializers.size()) {
ReportError("expected ", fields_size, " initializers for ", aggregate_name,
" found ", initializers.size());
}
}
InitializerResults ImplementationVisitor::VisitInitializerResults(
const AggregateType* current_aggregate,
const ClassType* class_type,
const std::vector<NameAndExpression>& initializers) {
InitializerResults result;
for (const NameAndExpression& initializer : initializers) {
result.names.push_back(initializer.name);
Expression* e = initializer.expression;
const Field& field =
current_aggregate->LookupField(initializer.name->value);
const Field& field = class_type->LookupField(initializer.name->value);
auto field_index = field.index;
if (SpreadExpression* s = SpreadExpression::DynamicCast(e)) {
if (!field_index) {
......@@ -1239,54 +1239,30 @@ InitializerResults ImplementationVisitor::VisitInitializerResults(
return result;
}
size_t ImplementationVisitor::InitializeAggregateHelper(
const AggregateType* aggregate_type, VisitResult allocate_result,
void ImplementationVisitor::InitializeClass(
const ClassType* class_type, VisitResult allocate_result,
const InitializerResults& initializer_results) {
const ClassType* current_class = ClassType::DynamicCast(aggregate_type);
size_t current = 0;
if (current_class) {
const ClassType* super = current_class->GetSuperClass();
if (super) {
current = InitializeAggregateHelper(super, allocate_result,
initializer_results);
}
if (const ClassType* super = class_type->GetSuperClass()) {
InitializeClass(super, allocate_result, initializer_results);
}
for (Field f : aggregate_type->fields()) {
if (current == initializer_results.field_value_map.size()) {
ReportError("insufficient number of initializers for ",
aggregate_type->name());
}
for (Field f : class_type->fields()) {
VisitResult current_value =
initializer_results.field_value_map.at(f.name_and_type.name);
Identifier* fieldname = initializer_results.names[current];
if (fieldname->value != f.name_and_type.name) {
CurrentSourcePosition::Scope scope(fieldname->pos);
ReportError("Expected fieldname \"", f.name_and_type.name,
"\" instead of \"", fieldname->value, "\"");
}
if (aggregate_type->IsClassType()) {
if (f.index) {
InitializeFieldFromSpread(allocate_result, f, initializer_results);
} else {
allocate_result.SetType(aggregate_type);
allocate_result.SetType(class_type);
GenerateCopy(allocate_result);
assembler().Emit(CreateFieldReferenceInstruction{
ClassType::cast(aggregate_type), f.name_and_type.name});
ClassType::cast(class_type), f.name_and_type.name});
VisitResult heap_reference(
TypeOracle::GetReferenceType(f.name_and_type.type),
assembler().TopRange(2));
GenerateAssignToLocation(
LocationReference::HeapReference(heap_reference), current_value);
}
} else {
LocationReference struct_field_ref = LocationReference::VariableAccess(
ProjectStructField(allocate_result, f.name_and_type.name));
GenerateAssignToLocation(struct_field_ref, current_value);
GenerateAssignToLocation(LocationReference::HeapReference(heap_reference),
current_value);
}
++current;
}
return current;
}
void ImplementationVisitor::InitializeFieldFromSpread(
......@@ -1305,17 +1281,6 @@ void ImplementationVisitor::InitializeFieldFromSpread(
{field.aggregate, index.type, iterator.type()});
}
void ImplementationVisitor::InitializeAggregate(
const AggregateType* aggregate_type, VisitResult allocate_result,
const InitializerResults& initializer_results) {
size_t consumed_initializers = InitializeAggregateHelper(
aggregate_type, allocate_result, initializer_results);
if (consumed_initializers != initializer_results.field_value_map.size()) {
ReportError("more initializers than fields present in ",
aggregate_type->name());
}
}
VisitResult ImplementationVisitor::AddVariableObjectSize(
VisitResult object_size, const ClassType* current_class,
const InitializerResults& initializer_results) {
......@@ -1398,6 +1363,11 @@ VisitResult ImplementationVisitor::Visit(NewExpression* expr) {
initializer_results.field_value_map[map_field.name_and_type.name] =
object_map;
}
CheckInitializersWellformed(class_type->name(),
class_type->ComputeAllFields(),
expr->initializers, !class_type->IsExtern());
Arguments size_arguments;
size_arguments.parameters.push_back(object_map);
VisitResult object_size = GenerateCall("%GetAllocationBaseSize",
......@@ -1412,7 +1382,7 @@ VisitResult ImplementationVisitor::Visit(NewExpression* expr) {
GenerateCall("%Allocate", allocate_arguments, {class_type}, false);
DCHECK(allocate_result.IsOnStack());
InitializeAggregate(class_type, allocate_result, initializer_results);
InitializeClass(class_type, allocate_result, initializer_results);
return stack_scope.Yield(allocate_result);
}
......@@ -1802,24 +1772,36 @@ VisitResult ImplementationVisitor::GenerateCopy(const VisitResult& to_copy) {
VisitResult ImplementationVisitor::Visit(StructExpression* expr) {
StackScope stack_scope(this);
const Type* raw_type = TypeVisitor::ComputeType(expr->type);
if (!raw_type->IsStructType()) {
ReportError(*raw_type, " is not a struct but used like one");
}
const StructType* struct_type = StructType::cast(raw_type);
auto& initializers = expr->initializers;
std::vector<VisitResult> values;
std::vector<const Type*> term_argument_types;
values.reserve(initializers.size());
term_argument_types.reserve(initializers.size());
InitializerResults initialization_results =
ImplementationVisitor::VisitInitializerResults(struct_type,
expr->initializers);
// Compute values and types of all initializer arguments
for (const NameAndExpression& initializer : initializers) {
VisitResult value = Visit(initializer.expression);
values.push_back(value);
term_argument_types.push_back(value.type());
}
// Push uninitialized 'this'
VisitResult result = TemporaryUninitializedStruct(
struct_type, "it's not initialized in the struct " + struct_type->name());
// Compute and check struct type from given struct name and argument types
const StructType* struct_type = TypeVisitor::ComputeTypeForStructExpression(
expr->type, term_argument_types);
CheckInitializersWellformed(struct_type->name(), struct_type->fields(),
initializers);
InitializeAggregate(struct_type, result, initialization_results);
// Implicitly convert values and thereby build the struct on the stack
StackRange struct_range = assembler().TopRange(0);
auto& fields = struct_type->fields();
for (size_t i = 0; i < values.size(); i++) {
values[i] =
GenerateImplicitConvert(fields[i].name_and_type.type, values[i]);
struct_range.Extend(values[i].stack_range());
}
return stack_scope.Yield(result);
return stack_scope.Yield(VisitResult(struct_type, struct_range));
}
LocationReference ImplementationVisitor::GetLocationReference(
......
......@@ -365,27 +365,26 @@ class ImplementationVisitor {
VisitResult Visit(Expression* expr);
const Type* Visit(Statement* stmt);
void CheckInitializersWellformed(
const std::string& aggregate_name,
const std::vector<Field>& aggregate_fields,
const std::vector<NameAndExpression>& initializers,
bool ignore_first_field = false);
InitializerResults VisitInitializerResults(
const AggregateType* aggregate,
const ClassType* class_type,
const std::vector<NameAndExpression>& expressions);
void InitializeFieldFromSpread(VisitResult object, const Field& field,
const InitializerResults& initializer_results);
size_t InitializeAggregateHelper(
const AggregateType* aggregate_type, VisitResult allocate_result,
const InitializerResults& initializer_results);
VisitResult AddVariableObjectSize(
VisitResult object_size, const ClassType* current_class,
const InitializerResults& initializer_results);
void InitializeAggregate(const AggregateType* aggregate_type,
VisitResult allocate_result,
void InitializeClass(const ClassType* class_type, VisitResult allocate_result,
const InitializerResults& initializer_results);
VisitResult TemporaryUninitializedStruct(const StructType* struct_type,
const std::string& reason);
VisitResult Visit(StructExpression* decl);
LocationReference GetLocationReference(Expression* location);
......
......@@ -11,7 +11,7 @@ namespace torque {
TypeArgumentInference::TypeArgumentInference(
const NameVector& type_parameters,
const TypeVector& explicit_type_arguments,
const std::vector<TypeExpression*> term_parameters,
const std::vector<TypeExpression*>& term_parameters,
const TypeVector& term_argument_types)
: num_explicit_(explicit_type_arguments.size()),
type_parameter_from_name_(type_parameters.size()),
......@@ -84,7 +84,7 @@ void TypeArgumentInference::Match(TypeExpression* parameter,
// argument types, but we are only interested in inferring type arguments
// here
} else {
Fail("unsupported parameter expression");
// TODO(gsps): Perform inference on function and union types
}
}
......
......@@ -46,12 +46,13 @@ namespace torque {
// Pick<Smi>(1, aSmi); // inference succeeds (doing nothing)
//
// In the above case the inference simply ignores inconsistent constraints on
// `T`.
// `T`. Similarly, we ignore all constraints arising from formal parameters
// that are function- or union-typed.
class TypeArgumentInference {
public:
TypeArgumentInference(const NameVector& type_parameters,
const TypeVector& explicit_type_arguments,
const std::vector<TypeExpression*> term_parameters,
const std::vector<TypeExpression*>& term_parameters,
const TypeVector& term_argument_types);
bool HasFailed() const { return failure_reason_.has_value(); }
......
......@@ -8,6 +8,7 @@
#include "src/torque/declarable.h"
#include "src/torque/global-context.h"
#include "src/torque/server-data.h"
#include "src/torque/type-inference.h"
#include "src/torque/type-oracle.h"
namespace v8 {
......@@ -121,6 +122,10 @@ const StructType* TypeVisitor::ComputeType(
CurrentSourcePosition::Scope position_activator(
field.name_and_type.type->pos);
const Type* field_type = TypeVisitor::ComputeType(field.name_and_type.type);
if (field_type->IsConstexpr()) {
ReportError("struct field \"", field.name_and_type.name->value,
"\" carries constexpr type \"", *field_type, "\"");
}
struct_type->RegisterField({field.name_and_type.name->pos,
struct_type,
base::nullopt,
......@@ -316,6 +321,53 @@ void TypeVisitor::VisitClassFieldsAndMethods(
DeclareMethods(class_type, class_declaration->methods);
}
const StructType* TypeVisitor::ComputeTypeForStructExpression(
TypeExpression* type_expression,
const std::vector<const Type*>& term_argument_types) {
auto* basic = BasicTypeExpression::DynamicCast(type_expression);
if (!basic) {
ReportError("expected basic type expression referring to struct");
}
QualifiedName qualified_name{basic->namespace_qualification, basic->name};
base::Optional<GenericStructType*> maybe_generic_struct =
Declarations::TryLookupGenericStructType(qualified_name);
// Compute types of non-generic structs as usual
if (!maybe_generic_struct) {
const Type* type = ComputeType(type_expression);
const StructType* struct_type = StructType::DynamicCast(type);
if (!struct_type) {
ReportError(*type, " is not a struct, but used like one");
}
return struct_type;
}
auto generic_struct = *maybe_generic_struct;
auto explicit_type_arguments = ComputeTypeVector(basic->generic_arguments);
std::vector<TypeExpression*> term_parameters;
auto& fields = generic_struct->declaration()->fields;
term_parameters.reserve(fields.size());
for (auto& field : fields) {
term_parameters.push_back(field.name_and_type.type);
}
TypeArgumentInference inference(
generic_struct->declaration()->generic_parameters,
explicit_type_arguments, term_parameters, term_argument_types);
if (inference.HasFailed()) {
ReportError("failed to infer type arguments for struct ", basic->name,
" initialization: ", inference.GetFailureReason());
}
if (GlobalContext::collect_language_server_data()) {
LanguageServerData::AddDefinition(type_expression->pos,
generic_struct->declaration()->name->pos);
}
return TypeOracle::GetGenericStructTypeInstance(generic_struct,
inference.GetResult());
}
} // namespace torque
} // namespace internal
} // namespace v8
......@@ -28,6 +28,9 @@ class TypeVisitor {
static void VisitClassFieldsAndMethods(
ClassType* class_type, const ClassDeclaration* class_declaration);
static Signature MakeSignature(const CallableNodeSignature* signature);
static const StructType* ComputeTypeForStructExpression(
TypeExpression* type_expression,
const std::vector<const Type*>& term_argument_types);
private:
friend class TypeAlias;
......
......@@ -418,6 +418,17 @@ void ClassType::Finalize() const {
CheckForDuplicateFields();
}
std::vector<Field> ClassType::ComputeAllFields() const {
std::vector<Field> all_fields;
const ClassType* super_class = this->GetSuperClass();
if (super_class) {
all_fields = super_class->ComputeAllFields();
}
const std::vector<Field>& fields = this->fields();
all_fields.insert(all_fields.end(), fields.begin(), fields.end());
return all_fields;
}
void ClassType::GenerateAccessors() {
// For each field, construct AST snippets that implement a CSA accessor
// function and define a corresponding '.field' operator. The
......
......@@ -550,6 +550,8 @@ class ClassType final : public AggregateType {
}
void Finalize() const override;
std::vector<Field> ComputeAllFields() const;
private:
friend class TypeOracle;
friend class TypeVisitor;
......
......@@ -580,7 +580,7 @@ TEST(TestGenericStruct2) {
i::HandleScope scope(isolate);
CodeAssemblerTester asm_tester(isolate);
TestTorqueAssembler m(asm_tester.state());
{ m.Return(m.TestGenericStruct2().fst); }
{ m.Return(m.TestGenericStruct2().snd.fst); }
FunctionTester ft(asm_tester.GenerateCode(), 0);
ft.Call();
}
......
......@@ -980,8 +980,8 @@ namespace test {
@export
macro TestGenericStruct1(): intptr {
const i: intptr = 123;
let box = SBox<intptr>{value: i};
let boxbox = SBox<SBox<intptr>>{value: box};
let box = SBox{value: i};
let boxbox: SBox<SBox<intptr>> = SBox{value: box};
check(box.value == 123);
boxbox.value.value *= 2;
check(boxbox.value.value == 246);
......@@ -995,16 +995,19 @@ namespace test {
macro TupleSwap<T1: type, T2: type>(tuple: TestTuple<T1, T2>):
TestTuple<T2, T1> {
return TestTuple<T2, T1>{fst: tuple.snd, snd: tuple.fst};
return TestTuple{fst: tuple.snd, snd: tuple.fst};
}
@export
macro TestGenericStruct2(): TestTuple<Smi, intptr> {
macro TestGenericStruct2():
TestTuple<TestTuple<intptr, Smi>, TestTuple<Smi, intptr>> {
const intptrAndSmi = TestTuple<intptr, Smi>{fst: 1, snd: 2};
const smiAndIntptr = TupleSwap(intptrAndSmi);
check(intptrAndSmi.fst == smiAndIntptr.snd);
check(intptrAndSmi.snd == smiAndIntptr.fst);
return smiAndIntptr;
const tupleTuple =
TestTuple<TestTuple<intptr, Smi>>{fst: intptrAndSmi, snd: smiAndIntptr};
return tupleTuple;
}
macro BranchAndWriteResult(x: Smi, box: SmiBox): bool {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment