diff --git a/awkward/array/chunked.py b/awkward/array/chunked.py index daa9928a..89588b6c 100644 --- a/awkward/array/chunked.py +++ b/awkward/array/chunked.py @@ -291,13 +291,6 @@ def __getitem__(self, where): self._valid() if self._util_isstringslice(where): - if isinstance(where, awkward.util.string): - if not self.type.hascolumn(where): - raise ValueError("no column named {0}".format(repr(where))) - else: - for x in where: - if not self.type.hascolumn(x): - raise ValueError("no column named {0}".format(repr(x))) chunks = [] counts = [] for chunk in self._chunks: @@ -480,8 +473,9 @@ def _aligned(self, what): def __setitem__(self, where, what): if isinstance(what, ChunkedArray) and self._aligned(what): - for mine, theirs in zip(self._chunks, what._chunks): + for i, (mine, theirs) in enumerate(zip(self._chunks, what._chunks)): mine[where] = theirs + self._types[i] = mine.type.to else: raise ValueError("only ChunkedArrays with the same chunk sizes can be assigned to columns of a ChunkedArray") diff --git a/awkward/type.py b/awkward/type.py index b37ce0ff..8783005e 100644 --- a/awkward/type.py +++ b/awkward/type.py @@ -189,6 +189,17 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) + @staticmethod + def _eq2(one, two, seen, ignoremask=False): + if isinstance(one, Type): + return one._eq(two, seen, ignoremask=ignoremask) + elif one == two: + return True + elif callable(one) and callable(two): + return True + else: + return False + @staticmethod def _finaltype(x): if isinstance(x, type) and issubclass(x, (numbers.Number, numpy.generic)): @@ -287,8 +298,10 @@ def _hascolumn(self, name, seen): return False elif isinstance(self._to, numpy.dtype): return name in self._to.names - else: + elif isinstance(self._to, Type): return self._to._hascolumn(name, seen) + else: + return False def _subrepr(self, labeled, seen): if isinstance(self._to, Type): @@ -315,10 +328,7 @@ def _eq(self, other, seen, ignoremask=False): else: seen.add(id(self)) if isinstance(other, ArrayType) and self._takes == other._takes: - if isinstance(self._to, Type): - return self._to._eq(other._to, seen, ignoremask=ignoremask) - else: - return self._to == other._to + return self._eq2(self._to, other._to, seen, ignoremask=ignoremask) else: return False @@ -410,12 +420,8 @@ def _eq(self, other, seen, ignoremask=False): seen.add(id(self)) if isinstance(other, TableType) and sorted(self._fields) == sorted(other._fields): for n in self._fields: - if isinstance(self._fields[n], Type): - if not self._fields[n]._eq(other._fields[n], seen, ignoremask=ignoremask): - return False - else: - if not self._fields[n] == other._fields[n]: - return False + if not self._eq2(self._fields[n], other._fields[n], seen, ignoremask=ignoremask): + return False else: return True # nothing failed in the loop over fields else: @@ -445,7 +451,13 @@ def _hascolumn(self, name, seen): if id(self) in seen: return False seen.add(id(self)) - return any(x._to._hascolumn(name, seen) for x in self._possibilities) + for x in self._possibilities: + if isinstance(x, numpy.dtype) and x.names is not None and name in x.names: + return True + elif isinstance(x, Type) and x._hascolumn(name, seen): + return True + else: + return False def __len__(self): return len(self._possibilities) @@ -491,12 +503,8 @@ def _eq(self, other, seen, ignoremask=False): seen.add(id(self)) if isinstance(other, UnionType) and len(self._possibilities) == len(other._possibilities): for x, y in zip(sorted(self._possibilities), sorted(self._possibilities)): - if isinstance(x, Type): - if not x._eq(y, seen, ignoremask=ignoremask): - return False - else: - if not x == y: - return False + if not self._eq2(x, y, seen, ignoremask=ignoremask): + return False else: return True # nothing failed in the loop over possibilities else: @@ -571,16 +579,10 @@ def _eq(self, other, seen, ignoremask=False): else: seen.add(id(self)) if isinstance(other, OptionType): - if isinstance(self._type, Type) and self._type._eq(other._type, seen, ignoremask=ignoremask): - return True - elif not isinstance(self._type, Type) and self._type == other._type: + if self._eq2(self._type, other._type, seen, ignoremask=ignoremask): return True - if ignoremask: # applied asymmetrically; only the left can ignore mask - if isinstance(self._type, Type): - return self._type._eq(other, seen, ignoremask=ignoremask) - else: - return self._type == other + return self._eq2(self._type, other, seen, ignoremask=ignoremask) else: return False diff --git a/awkward/version.py b/awkward/version.py index 3e62626c..153257c6 100644 --- a/awkward/version.py +++ b/awkward/version.py @@ -4,7 +4,7 @@ import re -__version__ = "0.10.2" +__version__ = "0.10.3" version = __version__ version_info = tuple(re.split(r"[-\.]", __version__))