Skip to content

Commit 4aaef75

Browse files
simon0-ofishy
authored and committed
THRIFT-5337 Go set fields write improvement
Client: go. The writeFields* functions check sets for duplicate elements and compare the elements using reflect.DeepEqual, which is expensive. It is much faster to generate an *Equals* function for set elements and call it in the duplicate-element check, especially for nested struct elements. Closes #2307.
1 parent 93d2099 commit 4aaef75

File tree

4 files changed

+549
-9
lines changed

4 files changed

+549
-9
lines changed

compiler/cpp/src/thrift/generate/t_go_generator.cc

+192-7
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ class t_go_generator : public t_generator {
152152
const string& tstruct_name,
153153
bool is_result = false,
154154
bool uses_countsetfields = false);
155+
void generate_go_struct_equals(std::ostream& out, t_struct* tstruct, const string& tstruct_name);
155156
void generate_go_function_helpers(t_function* tfunction);
156157
void get_publicized_name_and_def_value(t_field* tfield,
157158
string* OUT_pub_name,
@@ -229,6 +230,12 @@ class t_go_generator : public t_generator {
229230

230231
void generate_serialize_list_element(std::ostream& out, t_list* tlist, std::string iter);
231232

233+
void generate_go_equals(std::ostream& out, t_type* ttype, string tgt, string src);
234+
235+
void generate_go_equals_struct(std::ostream& out, t_type* ttype, string tgt, string src);
236+
237+
void generate_go_equals_container(std::ostream& out, t_type* ttype, string tgt, string src);
238+
232239
void generate_go_docstring(std::ostream& out, t_struct* tstruct);
233240

234241
void generate_go_docstring(std::ostream& out, t_function* tfunction);
@@ -307,6 +314,7 @@ class t_go_generator : public t_generator {
307314
std::set<std::string> package_identifiers_set_;
308315
std::string read_method_name_;
309316
std::string write_method_name_;
317+
std::string equals_method_name_;
310318

311319
std::set<std::string> commonInitialisms;
312320

@@ -724,6 +732,7 @@ void t_go_generator::init_generator() {
724732
read_method_name_ = "Read";
725733
write_method_name_ = "Write";
726734
}
735+
equals_method_name_ = "Equals";
727736

728737
while (true) {
729738
// TODO: Do better error checking here.
@@ -912,7 +921,6 @@ string t_go_generator::go_imports_begin(bool consts) {
912921
std::vector<string> system_packages;
913922
system_packages.push_back("bytes");
914923
system_packages.push_back("context");
915-
system_packages.push_back("reflect");
916924
// If not writing constants, and there are enums, need extra imports.
917925
if (!consts && get_program()->get_enums().size() > 0) {
918926
system_packages.push_back("database/sql/driver");
@@ -937,7 +945,6 @@ string t_go_generator::go_imports_end() {
937945
"var _ = thrift.ZERO\n"
938946
"var _ = fmt.Printf\n"
939947
"var _ = context.Background\n"
940-
"var _ = reflect.DeepEqual\n"
941948
"var _ = time.Now\n"
942949
"var _ = bytes.Equal\n\n");
943950
}
@@ -1482,6 +1489,9 @@ void t_go_generator::generate_go_struct_definition(ostream& out,
14821489
generate_isset_helpers(out, tstruct, tstruct_name, is_result);
14831490
generate_go_struct_reader(out, tstruct, tstruct_name, is_result);
14841491
generate_go_struct_writer(out, tstruct, tstruct_name, is_result, num_setable > 0);
1492+
if (!is_result && !is_args) {
1493+
generate_go_struct_equals(out, tstruct, tstruct_name);
1494+
}
14851495

14861496
out << indent() << "func (p *" << tstruct_name << ") String() string {" << endl;
14871497
out << indent() << " if p == nil {" << endl;
@@ -1851,6 +1861,61 @@ void t_go_generator::generate_go_struct_writer(ostream& out,
18511861
}
18521862
}
18531863

1864+
/**
 * Generates the Go equality method (named by equals_method_name_) for a struct.
 *
 * The emitted method has the shape
 *   func (p *<tstruct_name>) Equals(other *<tstruct_name>) bool
 * and returns true when both receivers are the same pointer, false when
 * exactly one of them is nil, and otherwise compares every member in sorted
 * member order, returning false at the first mismatch.
 *
 * @param out          Stream the generated Go source is written to
 * @param tstruct      Struct whose members are compared
 * @param tstruct_name Go type name used for the receiver and the argument
 */
void t_go_generator::generate_go_struct_equals(ostream& out,
                                               t_struct* tstruct,
                                               const string& tstruct_name) {
  const vector<t_field*>& fields = tstruct->get_sorted_members();

  indent(out) << "func (p *" << tstruct_name << ") " << equals_method_name_ << "(other *"
              << tstruct_name << ") bool {" << endl;
  indent_up();

  // Pointer identity (which also covers nil == nil) short-circuits to true;
  // a single nil side can never be equal to a non-nil one.
  out << indent() << "if p == other {" << endl;
  indent_up();
  out << indent() << "return true" << endl;
  indent_down();
  out << indent() << "} else if p == nil || other == nil {" << endl;
  indent_up();
  out << indent() << "return false" << endl;
  indent_down();
  out << indent() << "}" << endl;

  for (vector<t_field*>::const_iterator f_iter = fields.begin(); f_iter != fields.end();
       ++f_iter) {
    t_field* tfield = *f_iter;
    t_type* field_type = tfield->get_type();
    t_type* ttype = field_type->get_true_type();

    const string publicize_field_name = publicize(tfield->get_name());
    const string tgt = "p." + publicize_field_name;
    const string src = "other." + publicize_field_name;

    // Pointer-typed scalar/enum/container fields must be nil-checked and then
    // dereferenced before their values can be compared. Struct fields fall
    // through to the else branch, where the comparison is delegated to the
    // element's own equality method.
    const bool needs_deref
        = is_pointer_field(tfield)
          && (ttype->is_base_type() || ttype->is_enum() || ttype->is_container());

    if (needs_deref) {
      const string tgtv = "(*" + tgt + ")";
      const string srcv = "(*" + src + ")";
      // Equal pointers (including both nil) need no further comparison.
      out << indent() << "if " << tgt << " != " << src << " {" << endl;
      indent_up();
      out << indent() << "if " << tgt << " == nil || " << src << " == nil {" << endl;
      indent_up();
      out << indent() << "return false" << endl;
      indent_down();
      out << indent() << "}" << endl;
      generate_go_equals(out, field_type, tgtv, srcv);
      indent_down();
      out << indent() << "}" << endl;
    } else {
      generate_go_equals(out, field_type, tgt, src);
    }
  }

  out << indent() << "return true" << endl;
  indent_down();
  out << indent() << "}" << endl << endl;
}
1918+
18541919
/**
18551920
* Generates a thrift service.
18561921
*
@@ -3389,15 +3454,30 @@ void t_go_generator::generate_serialize_container(ostream& out,
33893454
} else if (ttype->is_set()) {
33903455
t_set* tset = (t_set*)ttype;
33913456
out << indent() << "for i := 0; i<len(" << prefix << "); i++ {" << endl;
3392-
out << indent() << " for j := i+1; j<len(" << prefix << "); j++ {" << endl;
3457+
indent_up();
3458+
out << indent() << "for j := i+1; j<len(" << prefix << "); j++ {" << endl;
3459+
indent_up();
33933460
string wrapped_prefix = prefix;
33943461
if (pointer_field) {
33953462
wrapped_prefix = "(" + prefix + ")";
33963463
}
3397-
out << indent() << " if reflect.DeepEqual(" << wrapped_prefix << "[i]," << wrapped_prefix << "[j]) { " << endl;
3398-
out << indent() << " return thrift.PrependError(\"\", fmt.Errorf(\"%T error writing set field: slice is not unique\", " << wrapped_prefix << "[i]))" << endl;
3399-
out << indent() << " }" << endl;
3400-
out << indent() << " }" << endl;
3464+
string goType = type_to_go_type(tset->get_elem_type());
3465+
out << indent() << "if func(tgt, src " << goType << ") bool {" << endl;
3466+
indent_up();
3467+
generate_go_equals(out, tset->get_elem_type(), "tgt", "src");
3468+
out << indent() << "return true" << endl;
3469+
indent_down();
3470+
out << indent() << "}(" << wrapped_prefix << "[i], " << wrapped_prefix << "[j]) {" << endl;
3471+
indent_up();
3472+
out << indent()
3473+
<< "return thrift.PrependError(\"\", fmt.Errorf(\"%T error writing set field: slice is not "
3474+
"unique\", "
3475+
<< wrapped_prefix << "))" << endl;
3476+
indent_down();
3477+
out << indent() << "}" << endl;
3478+
indent_down();
3479+
out << indent() << "}" << endl;
3480+
indent_down();
34013481
out << indent() << "}" << endl;
34023482
out << indent() << "for _, v := range " << prefix << " {" << endl;
34033483
indent_up();
@@ -3463,6 +3543,111 @@ void t_go_generator::generate_serialize_list_element(ostream& out, t_list* tlist
34633543
generate_serialize_field(out, &efield, prefix);
34643544
}
34653545

3546+
/**
3547+
* Compares any type
3548+
*/
3549+
void t_go_generator::generate_go_equals(ostream& out, t_type* ori_type, string tgt, string src) {
3550+
3551+
t_type* ttype = get_true_type(ori_type);
3552+
// Do nothing for void types
3553+
if (ttype->is_void()) {
3554+
throw "compiler error: cannot generate equals for void type: " + tgt;
3555+
}
3556+
3557+
if (ttype->is_struct() || ttype->is_xception()) {
3558+
generate_go_equals_struct(out, ttype, tgt, src);
3559+
} else if (ttype->is_container()) {
3560+
generate_go_equals_container(out, ttype, tgt, src);
3561+
} else if (ttype->is_base_type() || ttype->is_enum()) {
3562+
out << indent() << "if ";
3563+
if (ttype->is_base_type()) {
3564+
t_base_type::t_base tbase = ((t_base_type*)ttype)->get_base();
3565+
switch (tbase) {
3566+
case t_base_type::TYPE_VOID:
3567+
throw "compiler error: cannot equals void: " + tgt;
3568+
break;
3569+
3570+
case t_base_type::TYPE_STRING:
3571+
if (ttype->is_binary()) {
3572+
out << "bytes.Compare(" << tgt << ", " << src << ") != 0";
3573+
} else {
3574+
out << tgt << " != " << src;
3575+
}
3576+
break;
3577+
3578+
case t_base_type::TYPE_BOOL:
3579+
case t_base_type::TYPE_I8:
3580+
case t_base_type::TYPE_I16:
3581+
case t_base_type::TYPE_I32:
3582+
case t_base_type::TYPE_I64:
3583+
case t_base_type::TYPE_DOUBLE:
3584+
out << tgt << " != " << src;
3585+
break;
3586+
3587+
default:
3588+
throw "compiler error: no Go name for base type " + t_base_type::t_base_name(tbase);
3589+
}
3590+
} else if (ttype->is_enum()) {
3591+
out << tgt << " != " << src;
3592+
}
3593+
3594+
out << " { return false }" << endl;
3595+
} else {
3596+
throw "compiler error: Invalid type in generate_go_equals '" + ttype->get_name() + "' for '"
3597+
+ tgt + "'";
3598+
}
3599+
}
3600+
3601+
/**
3602+
* Compares the members of a struct
3603+
*/
3604+
void t_go_generator::generate_go_equals_struct(ostream& out,
3605+
t_type* ttype,
3606+
string tgt,
3607+
string src) {
3608+
(void)ttype;
3609+
out << indent() << "if !" << tgt << "." << equals_method_name_ << "(" << src
3610+
<< ") { return false }" << endl;
3611+
}
3612+
3613+
/**
3614+
* Compares any container type
3615+
*/
3616+
void t_go_generator::generate_go_equals_container(ostream& out,
3617+
t_type* ttype,
3618+
string tgt,
3619+
string src) {
3620+
out << indent() << "if len(" << tgt << ") != len(" << src << ") { return false }" << endl;
3621+
if (ttype->is_map()) {
3622+
t_map* tmap = (t_map*)ttype;
3623+
out << indent() << "for k, _tgt := range " << tgt << " {" << endl;
3624+
indent_up();
3625+
string element_source = tmp("_src");
3626+
out << indent() << element_source << " := " << src << "[k]" << endl;
3627+
generate_go_equals(out, tmap->get_val_type(), "_tgt", element_source);
3628+
indent_down();
3629+
indent(out) << "}" << endl;
3630+
} else if (ttype->is_list() || ttype->is_set()) {
3631+
t_type* elem;
3632+
if (ttype->is_list()) {
3633+
t_list* temp = (t_list*)ttype;
3634+
elem = temp->get_elem_type();
3635+
} else {
3636+
t_set* temp = (t_set*)ttype;
3637+
elem = temp->get_elem_type();
3638+
}
3639+
out << indent() << "for i, _tgt := range " << tgt << " {" << endl;
3640+
indent_up();
3641+
string element_source = tmp("_src");
3642+
out << indent() << element_source << " := " << src << "[i]" << endl;
3643+
generate_go_equals(out, elem, "_tgt", element_source);
3644+
indent_down();
3645+
indent(out) << "}" << endl;
3646+
} else {
3647+
throw "INVALID TYPE IN generate_go_equals_container '" + ttype->get_name();
3648+
}
3649+
}
3650+
34663651
/**
34673652
* Generates the docstring for a given struct.
34683653
*/

lib/go/test/EqualsTest.thrift

+109
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
typedef i8 mybyte
2+
typedef string mystr
3+
typedef binary mybin
4+
5+
enum EnumFoo {
6+
e1
7+
e2
8+
}
9+
10+
struct BasicEqualsFoo {
11+
1: bool BoolFoo,
12+
2: optional bool OptBoolFoo,
13+
3: i8 I8Foo,
14+
4: optional i8 OptI8Foo,
15+
5: i16 I16Foo,
16+
6: optional i16 OptI16Foo,
17+
7: i32 I32Foo,
18+
8: optional i32 OptI32Foo,
19+
9: i64 I64Foo,
20+
10: optional i64 OptI64Foo,
21+
11: double DoubleFoo,
22+
12: optional double OptDoubleFoo,
23+
13: string StrFoo,
24+
14: optional string OptStrFoo,
25+
15: binary BinFoo,
26+
16: optional binary OptBinFoo,
27+
17: EnumFoo EnumFoo,
28+
18: optional EnumFoo OptEnumFoo,
29+
19: mybyte MyByteFoo,
30+
20: optional mybyte OptMyByteFoo,
31+
21: mystr MyStrFoo,
32+
22: optional mystr OptMyStrFoo,
33+
23: mybin MyBinFoo,
34+
24: optional mybin OptMyBinFoo,
35+
}
36+
37+
struct StructEqualsFoo {
38+
1: BasicEqualsFoo StructFoo,
39+
2: optional BasicEqualsFoo OptStructFoo,
40+
}
41+
42+
struct ListEqualsFoo {
43+
1: list<i64> I64ListFoo,
44+
2: optional list<i64> OptI64ListFoo,
45+
3: list<string> StrListFoo,
46+
4: optional list<string> OptStrListFoo,
47+
5: list<binary> BinListFoo,
48+
6: optional list<binary> OptBinListFoo,
49+
7: list<BasicEqualsFoo> StructListFoo,
50+
8: optional list<BasicEqualsFoo> OptStructListFoo,
51+
9: list<list<i64>> I64ListListFoo,
52+
10: optional list<list<i64>> OptI64ListListFoo,
53+
11: list<set<i64>> I64SetListFoo,
54+
12: optional list<set<i64>> OptI64SetListFoo,
55+
13: list<map<i64, string>> I64StringMapListFoo,
56+
14: optional list<map<i64, string>> OptI64StringMapListFoo,
57+
15: list<mybyte> MyByteListFoo,
58+
16: optional list<mybyte> OptMyByteListFoo,
59+
17: list<mystr> MyStrListFoo,
60+
18: optional list<mystr> OptMyStrListFoo,
61+
19: list<mybin> MyBinListFoo,
62+
20: optional list<mybin> OptMyBinListFoo,
63+
}
64+
65+
struct SetEqualsFoo {
66+
1: set<i64> I64SetFoo,
67+
2: optional set<i64> OptI64SetFoo,
68+
3: set<string> StrSetFoo,
69+
4: optional set<string> OptStrSetFoo,
70+
5: set<binary> BinSetFoo,
71+
6: optional set<binary> OptBinSetFoo,
72+
7: set<BasicEqualsFoo> StructSetFoo,
73+
8: optional set<BasicEqualsFoo> OptStructSetFoo,
74+
9: set<list<i64>> I64ListSetFoo,
75+
10: optional set<list<i64>> OptI64ListSetFoo,
76+
11: set<set<i64>> I64SetSetFoo,
77+
12: optional set<set<i64>> OptI64SetSetFoo,
78+
13: set<map<i64, string>> I64StringMapSetFoo,
79+
14: optional set<map<i64, string>> OptI64StringMapSetFoo,
80+
15: set<mybyte> MyByteSetFoo,
81+
16: optional set<mybyte> OptMyByteSetFoo,
82+
17: set<mystr> MyStrSetFoo,
83+
18: optional set<mystr> OptMyStrSetFoo,
84+
19: set<mybin> MyBinSetFoo,
85+
20: optional set<mybin> OptMyBinSetFoo,
86+
}
87+
88+
struct MapEqualsFoo {
89+
1: map<i64, string> I64StrMapFoo,
90+
2: optional map<i64, string> OptI64StrMapFoo,
91+
3: map<string, i64> StrI64MapFoo,
92+
4: optional map<string, i64> OptStrI64MapFoo,
93+
5: map<BasicEqualsFoo, binary> StructBinMapFoo,
94+
6: optional map<BasicEqualsFoo, binary> OptStructBinMapFoo,
95+
7: map<binary, BasicEqualsFoo> BinStructMapFoo,
96+
8: optional map<binary, BasicEqualsFoo> OptBinStructMapFoo,
97+
9: map<i64, list<i64>> I64I64ListMapFoo,
98+
10: optional map<i64, list<i64>> OptI64I64ListMapFoo,
99+
11: map<i64, set<i64>> I64I64SetMapFoo,
100+
12: optional map<i64, set<i64>> OptI64I64SetMapFoo,
101+
13: map<i64, map<i64, string>> I64I64StringMapMapFoo,
102+
14: optional map<i64, map<i64, string>> OptI64I64StringMapMapFoo,
103+
15: map<mystr, mybin> MyStrMyBinMapFoo,
104+
16: optional map<mystr, mybin> OptMyStrMyBinMapFoo,
105+
17: map<i64, mybyte> Int64MyByteMapFoo,
106+
18: optional map<i64, mybyte> OptInt64MyByteMapFoo,
107+
19: map<mybyte, i64> MyByteInt64MapFoo,
108+
20: optional map<mybyte, i64> OptMyByteInt64MapFoo,
109+
}

0 commit comments

Comments
 (0)