记录博客 ZH-BLOG

Python 重载运算符

时间:2018-08-17 14:36:13分类:python

一元运算符

-(__neg__)

一元取负算术运算符。 如果 x 是 -2, 那么 -x == 2。

+(__pos__)

一元取正算术运算符。 通常, x == +x, 但也有一些例外。

~(__invert__)

对整数按位取反, 定义为 ~x == -(x+1)。 如果 x 是 2, 那么 ~x == -3。

x 和 +x 可能不相等,是因为 +x 创建了新的实例对象

>>> import decimal
>>> ctx = decimal.getcontext()
>>> ctx.prec = 40
>>> one_third = decimal.Decimal(1) / decimal.Decimal(3)
>>> one_third
Decimal('0.3333333333333333333333333333333333333333')
>>> one_third == +one_third
True
>>> ctx.prec = 28
>>> one_third == +one_third
False
>>> one_third
Decimal('0.3333333333333333333333333333333333333333')
>>> +one_third  # +one_third 会使用 one_third 的值创建一个新的 Decimal 实例
Decimal('0.3333333333333333333333333333')

>>> from collections import Counter
>>> ct = Counter('abcabdfnabdscd')
>>> ct
Counter({'a': 3, 'b': 3, 'd': 3, 'c': 2, 'f': 1, 'n': 1, 's': 1})
>>> ct['s'] = -3
>>> ct['n'] = 0
>>> ct
Counter({'a': 3, 'b': 3, 'd': 3, 'c': 2, 'f': 1, 'n': 0, 's': -3})
>>> +ct  # 创建一个新的 Counter 且仅保留大于零的计数器
Counter({'a': 3, 'b': 3, 'd': 3, 'c': 2, 'f': 1})

Vector 类实现 +v、-v

def __pos__(self):
	return Vector(self)


def __len__(self):
	return len(self._components)

重载加法运算符 +

def __add__(self, other):
        pairs = itertools.zip_longest(self, other, fillvalue=0.0)
        return Vector(a + b for a, b in pairs)

>>> v1 = Vector([3, 4, 5])
>>> v2 = Vector([6, 7, 8])
>>> v1 + v2
Vector([9.0, 11.0, 13.0])
>>> v3 = Vector([1, 2]) # 长度不一致用 0.0 填充
>>> v1 + v3
Vector([4.0, 6.0, 5.0])
>>> v1 + (10, 20, 30) # 可迭代对象均可
Vector([13.0, 24.0, 35.0])
>>> (10, 20, 30) + v1 # 反向操作不行
Traceback (most recent call last):
...
TypeError: can only concatenate tuple (not "Vector") to tuple

对表达式 a + b 来说, 解释器会执行以下几步操作:

(1) 如果 a 有 __add__ 方法, 而且返回值不是 NotImplemented, 调用 a.__add__(b), 然后返回结果。

(2) 如果 a 没有 __add__ 方法, 或者调用 __add__ 方法返回 NotImplemented, 检查 b 有没有 __radd__ 方法, 如果有, 而且没有返回 NotImplemented, 调用 b.__radd__(a), 然后返回结果。

(3) 如果 b 没有 __radd__ 方法, 或者调用 __radd__ 方法返回 NotImplemented, 抛出 TypeError, 并在错误消息中指明操作数类型不支持。

Vector 实现 __radd__ 方法

def __radd__(self, other):
        return self + other  # 仍然由 __add__ 处理

在错误消息中指明操作数类型不支持

def __add__(self, other):
	try:
		pairs = itertools.zip_longest(self, other, fillvalue=0.0)
		return Vector(a + b for a, b in pairs)
	except TypeError:
		return NotImplemented


def __radd__(self, other):
	return self + other

重载乘法运算符 *

乘法运算符与上面加法类似

def __mul__(self, scalar):
	if isinstance(scalar, numbers.Real):
		return Vector(n * scalar for n in self)
	else:
		return NotImplemented


def __rmul__(self, scalar):
	return self * scalar

>>> v1 = Vector([1, 2, 3])
>>> 14 * v1
Vector([14.0, 28.0, 42.0])
>>> v1 * True
Vector([1.0, 2.0, 3.0])
>>> from fractions import Fraction
>>> v1 * Fraction(1, 3)
Vector([0.3333333333333333, 0.6666666666666666, 1.0])

矩阵乘法 @

def __matmul__(self, other):
	try:
		return sum(a * b for a, b in zip(self, other))
	except TypeError:
		return NotImplemented


def __rmatmul__(self, other):
	return self @ other

>>> v1 = Vector([1, 2, 3])
>>> v2 = Vector([2, 3, 4])
>>> v1 @ v2
20.0

相等运算符 ==

def __eq__(self, other):
	# return len(self) == len(other) and all(a!=b for a,b in zip(self, other))
	# if len(self) != len(other):
	#    return False
	# for a, b in zip(self, other):
	#    if a != b:
	#        return False
	# return True
	if isinstance(other, Vector):
		return (len(self) == len(other) and all(a == b for a, b in zip(self, other)))
	else:
		return NotImplemented

>>> Vector = vector.Vector
>>> v1 = Vector([1.0, 2.0, 3.0])
>>> v2 = Vector(range(1, 4))
>>> v1 == v2
True
>>> v3 = Vector([1, 2])
>>> v1 == v3
False
>>> t3 = (1, 2, 3)
>>> v1 == t3
False

Vector 实例和元组比较时, 具体步骤如下:

(1) 为了计算 v1 == t3, Python 调用 Vector.__eq__(v1, t3)。

(2) 经 Vector.__eq__(va, t3) 确认, t3 不是 Vector 实例, 因此返回 NotImplemented。

(3) Python 得到 NotImplemented 结果, 尝试调用 tuple.__eq__(t3, v1)。

(4) tuple.__eq__(t3, v1) 不知道 Vector 是什么, 因此返回 NotImplemented。

(5) 对 == 来说, 如果反向调用返回 NotImplemented, Python 会比较对象的 ID, 作最后一搏。

!= 运算符不用实现它, 因为从 object 继承的 __ne__ 方法的后备行为满足了我们的需求: 定义了 __eq__ 方法, 而且它不返回 NotImplemented, __ne__ 会对 __eq__ 返回的结果取反。

def __ne__(self, other):
	eq_result = self == other
	if eq_result is NotImplemented:
		return NotImplemented
	else:
		return not eq_result