Skip to content

Commit f9363f0

Browse files
committed
rt: Fix comparison of interior vectors
1 parent dcc9a81 commit f9363f0

File tree

1 file changed

+63
-5
lines changed

1 file changed

+63
-5
lines changed

src/rt/rust_shape.cpp

+63-5
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,17 @@ align_to(T size, size_t alignment) {
7777

7878
template<typename T>
7979
static inline T
80-
bump_dp(uint8_t *dp) {
80+
bump_dp(uint8_t *&dp) {
8181
T x = *((T *)dp);
8282
dp += sizeof(T);
8383
return x;
8484
}
8585

86+
template<typename T>
87+
static inline T
88+
get_dp(uint8_t *dp) {
89+
return *((T *)dp);
90+
}
8691

8792
// Utility classes
8893

@@ -145,6 +150,11 @@ class data_pair {
145150
data_pair(T &in_fst, T &in_snd) : fst(in_fst), snd(in_snd) {}
146151

147152
inline void operator=(const T rhs) { fst = snd = rhs; }
153+
154+
static data_pair<T> make(T &fst, T &snd) {
155+
data_pair<T> data(fst, snd);
156+
return data;
157+
}
148158
};
149159

150160
class ptr_pair {
@@ -201,6 +211,14 @@ bump_dp(ptr_pair &ptr) {
201211
return data;
202212
}
203213

214+
template<typename T>
215+
inline data_pair<T>
216+
get_dp(ptr_pair &ptr) {
217+
data_pair<T> data(*reinterpret_cast<T *>(ptr.fst),
218+
*reinterpret_cast<T *>(ptr.snd));
219+
return data;
220+
}
221+
204222

205223
// Contexts
206224

@@ -231,6 +249,7 @@ class ctxt {
231249
task(other.task) {}
232250

233251
void walk(bool align);
252+
void walk_reset(bool align);
234253

235254
std::pair<const uint8_t *,const uint8_t *>
236255
get_variant_sp(tag_info &info, uint32_t variant_id);
@@ -347,6 +366,8 @@ struct type_param {
347366
template<typename T>
348367
void
349368
ctxt<T>::walk(bool align) {
369+
fprintf(stderr, "walking %d\n", *sp);
370+
350371
switch (*sp++) {
351372
case SHAPE_U8: WALK_NUMBER(uint8_t); break;
352373
case SHAPE_U16: WALK_NUMBER(uint16_t); break;
@@ -374,6 +395,14 @@ ctxt<T>::walk(bool align) {
374395
}
375396
}
376397

398+
template<typename T>
399+
void
400+
ctxt<T>::walk_reset(bool align) {
401+
const uint8_t *old_sp = sp;
402+
walk(align);
403+
sp = old_sp;
404+
}
405+
377406
template<typename T>
378407
uint16_t
379408
ctxt<T>::get_u16(const uint8_t *addr) {
@@ -816,8 +845,9 @@ size_of::walk_ivec(bool align, bool is_pod, size_align &elem_sa) {
816845

817846
#define DATA_SIMPLE(ty, call) \
818847
if (align) dp = align_to(dp, sizeof(ty)); \
848+
U end_dp = dp + sizeof(ty); \
819849
static_cast<T *>(this)->call; \
820-
dp += sizeof(ty);
850+
dp = end_dp;
821851

822852
template<typename T,typename U>
823853
class data : public ctxt< data<T,U> > {
@@ -894,6 +924,7 @@ std::pair<uint8_t *,uint8_t *>
894924
data<T,U>::get_ivec_data_range(uint8_t *dp) {
895925
size_t fill = bump_dp<size_t>(dp);
896926
bump_dp<size_t>(dp); // Skip over alloc.
927+
uint8_t *payload_dp = dp;
897928
rust_ivec_payload payload = bump_dp<rust_ivec_payload>(dp);
898929

899930
uint8_t *start, *end;
@@ -906,7 +937,7 @@ data<T,U>::get_ivec_data_range(uint8_t *dp) {
906937
end = start + fill;
907938
}
908939
} else { // On stack.
909-
start = payload.data;
940+
start = payload_dp;
910941
end = start + fill;
911942
}
912943

@@ -916,6 +947,7 @@ data<T,U>::get_ivec_data_range(uint8_t *dp) {
916947
template<typename T,typename U>
917948
std::pair<ptr_pair,ptr_pair>
918949
data<T,U>::get_ivec_data_range(ptr_pair &dp) {
950+
fprintf(stderr, "get_ivec_data_range %p/%p\n", dp.fst, dp.snd);
919951
std::pair<uint8_t *,uint8_t *> fst = get_ivec_data_range(dp.fst);
920952
std::pair<uint8_t *,uint8_t *> snd = get_ivec_data_range(dp.snd);
921953
ptr_pair start(fst.first, snd.first);
@@ -1021,19 +1053,43 @@ class cmp : public data<cmp,ptr_pair> {
10211053
variant_ptr_and_end);
10221054

10231055
template<typename T>
1024-
void walk_number() { cmp_number(bump_dp<T>(dp)); }
1056+
void walk_number() { cmp_number(get_dp<T>(dp)); }
10251057
};
10261058

1059+
template<>
1060+
void cmp::cmp_number<int32_t>(const data_pair<int32_t> &nums) {
1061+
fprintf(stderr, "cmp %d/%d\n", nums.fst, nums.snd);
1062+
result = (nums.fst < nums.snd) ? -1 : (nums.fst == nums.snd) ? 0 : 1;
1063+
}
1064+
10271065
void
10281066
cmp::walk_ivec(bool align, bool is_pod, size_align &elem_sa) {
10291067
std::pair<ptr_pair,ptr_pair> data_range = get_ivec_data_range(dp);
10301068

1069+
DPRINT("walk_ivec %p/%p\n", data_range.first.fst, data_range.first.snd);
1070+
10311071
cmp sub(*this, data_range.first);
10321072
ptr_pair data_end = data_range.second;
10331073
while (!result && sub.dp < data_end) {
1034-
sub.walk(align);
1074+
DPRINT("walk_ivec elem %p/%p %p/%p\n", sub.dp.fst, sub.dp.snd,
1075+
data_end.fst, data_end.snd);
1076+
DPRINTCX(&sub);
1077+
DPRINT("\nend\n");
1078+
1079+
sub.walk_reset(align);
1080+
DPRINT("result = %d\n", sub.result);
10351081
result = sub.result;
10361082
align = true;
1083+
1084+
DPRINT("walk_ivec after elem %p/%p %p/%p\n", sub.dp.fst, sub.dp.snd,
1085+
data_end.fst, data_end.snd);
1086+
}
1087+
1088+
if (!result) {
1089+
// If we hit the end, the result comes down to length comparison.
1090+
int len_fst = data_range.second.fst - data_range.first.fst;
1091+
int len_snd = data_range.second.snd - data_range.first.snd;
1092+
cmp_number(data_pair<int>::make(len_fst, len_snd));
10371093
}
10381094
}
10391095

@@ -1090,6 +1146,8 @@ extern "C" void
10901146
upcall_cmp_type(int8_t *result, rust_task *task, type_desc *tydesc,
10911147
const type_desc **subtydescs, uint8_t *data_0,
10921148
uint8_t *data_1, uint8_t cmp_type) {
1149+
fprintf(stderr, "cmp_type\n");
1150+
10931151
shape::arena arena;
10941152
shape::type_param *params = shape::type_param::make(tydesc, arena);
10951153
shape::cmp cmp(task, tydesc->shape, params, tydesc->shape_tables, data_0,

0 commit comments

Comments
 (0)