DynamoDB: Fix calculation when adding/subtracting decimals (#7365)
This commit is contained in:
parent
7f6c9cb1de
commit
7c87bddeae
@ -1,6 +1,6 @@
|
|||||||
import base64
|
import base64
|
||||||
import copy
|
import copy
|
||||||
import decimal
|
from decimal import Decimal
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from boto3.dynamodb.types import TypeDeserializer, TypeSerializer
|
from boto3.dynamodb.types import TypeDeserializer, TypeSerializer
|
||||||
@ -100,9 +100,14 @@ class DynamoType(object):
|
|||||||
if self.type != other.type:
|
if self.type != other.type:
|
||||||
raise TypeError("Different types of operandi is not allowed.")
|
raise TypeError("Different types of operandi is not allowed.")
|
||||||
if self.is_number():
|
if self.is_number():
|
||||||
self_value = float(self.value) if "." in self.value else int(self.value)
|
self_value: Union[Decimal, int] = (
|
||||||
other_value = float(other.value) if "." in other.value else int(other.value)
|
Decimal(self.value) if "." in self.value else int(self.value)
|
||||||
return DynamoType({DDBType.NUMBER: f"{self_value + other_value}"})
|
)
|
||||||
|
other_value: Union[Decimal, int] = (
|
||||||
|
Decimal(other.value) if "." in other.value else int(other.value)
|
||||||
|
)
|
||||||
|
total = self_value + other_value
|
||||||
|
return DynamoType({DDBType.NUMBER: f"{total}"})
|
||||||
else:
|
else:
|
||||||
raise IncorrectDataType()
|
raise IncorrectDataType()
|
||||||
|
|
||||||
@ -385,12 +390,7 @@ class Item(BaseModel):
|
|||||||
if set(update_action["Value"].keys()) == set(["N"]):
|
if set(update_action["Value"].keys()) == set(["N"]):
|
||||||
existing = self.attrs.get(attribute_name, DynamoType({"N": "0"}))
|
existing = self.attrs.get(attribute_name, DynamoType({"N": "0"}))
|
||||||
self.attrs[attribute_name] = DynamoType(
|
self.attrs[attribute_name] = DynamoType(
|
||||||
{
|
{"N": str(Decimal(existing.value) + Decimal(new_value))}
|
||||||
"N": str(
|
|
||||||
decimal.Decimal(existing.value)
|
|
||||||
+ decimal.Decimal(new_value)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
elif set(update_action["Value"].keys()) == set(["SS"]):
|
elif set(update_action["Value"].keys()) == set(["SS"]):
|
||||||
existing = self.attrs.get(attribute_name, DynamoType({"SS": {}}))
|
existing = self.attrs.get(attribute_name, DynamoType({"SS": {}}))
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -40,3 +42,50 @@ def test_update_different_map_elements_in_single_request(table_name=None):
|
|||||||
ExpressionAttributeValues={":MyCount": 5},
|
ExpressionAttributeValues={":MyCount": 5},
|
||||||
)
|
)
|
||||||
assert table.get_item(Key={"pk": "example_id"})["Item"]["MyTotalCount"] == 5
|
assert table.get_item(Key={"pk": "example_id"})["Item"]["MyTotalCount"] == 5
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.aws_verified
|
||||||
|
@dynamodb_aws_verified()
|
||||||
|
def test_update_item_add_float(table_name=None):
|
||||||
|
table = boto3.resource("dynamodb", "us-east-1").Table(table_name)
|
||||||
|
|
||||||
|
# DECIMAL - DECIMAL
|
||||||
|
table.put_item(Item={"pk": "foo", "amount": Decimal(100), "nr": 5})
|
||||||
|
table.update_item(
|
||||||
|
Key={"pk": "foo"},
|
||||||
|
UpdateExpression="ADD amount :delta",
|
||||||
|
ExpressionAttributeValues={":delta": -Decimal("88.3")},
|
||||||
|
)
|
||||||
|
assert table.scan()["Items"][0]["amount"] == Decimal("11.7")
|
||||||
|
|
||||||
|
# DECIMAL + DECIMAL
|
||||||
|
table.update_item(
|
||||||
|
Key={"pk": "foo"},
|
||||||
|
UpdateExpression="ADD amount :delta",
|
||||||
|
ExpressionAttributeValues={":delta": Decimal("25.41")},
|
||||||
|
)
|
||||||
|
assert table.scan()["Items"][0]["amount"] == Decimal("37.11")
|
||||||
|
|
||||||
|
# DECIMAL + INT
|
||||||
|
table.update_item(
|
||||||
|
Key={"pk": "foo"},
|
||||||
|
UpdateExpression="ADD amount :delta",
|
||||||
|
ExpressionAttributeValues={":delta": 6},
|
||||||
|
)
|
||||||
|
assert table.scan()["Items"][0]["amount"] == Decimal("43.11")
|
||||||
|
|
||||||
|
# INT + INT
|
||||||
|
table.update_item(
|
||||||
|
Key={"pk": "foo"},
|
||||||
|
UpdateExpression="ADD nr :delta",
|
||||||
|
ExpressionAttributeValues={":delta": 1},
|
||||||
|
)
|
||||||
|
assert table.scan()["Items"][0]["nr"] == Decimal("6")
|
||||||
|
|
||||||
|
# INT + DECIMAL
|
||||||
|
table.update_item(
|
||||||
|
Key={"pk": "foo"},
|
||||||
|
UpdateExpression="ADD nr :delta",
|
||||||
|
ExpressionAttributeValues={":delta": Decimal("25.41")},
|
||||||
|
)
|
||||||
|
assert table.scan()["Items"][0]["nr"] == Decimal("31.41")
|
||||||
|
Loading…
Reference in New Issue
Block a user