DynamoDB: Fix calculation when adding/subtracting decimals (#7365)
This commit is contained in:
parent
7f6c9cb1de
commit
7c87bddeae
@ -1,6 +1,6 @@
|
||||
import base64
|
||||
import copy
|
||||
import decimal
|
||||
from decimal import Decimal
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from boto3.dynamodb.types import TypeDeserializer, TypeSerializer
|
||||
@ -100,9 +100,14 @@ class DynamoType(object):
|
||||
if self.type != other.type:
|
||||
raise TypeError("Different types of operandi is not allowed.")
|
||||
if self.is_number():
|
||||
self_value = float(self.value) if "." in self.value else int(self.value)
|
||||
other_value = float(other.value) if "." in other.value else int(other.value)
|
||||
return DynamoType({DDBType.NUMBER: f"{self_value + other_value}"})
|
||||
self_value: Union[Decimal, int] = (
|
||||
Decimal(self.value) if "." in self.value else int(self.value)
|
||||
)
|
||||
other_value: Union[Decimal, int] = (
|
||||
Decimal(other.value) if "." in other.value else int(other.value)
|
||||
)
|
||||
total = self_value + other_value
|
||||
return DynamoType({DDBType.NUMBER: f"{total}"})
|
||||
else:
|
||||
raise IncorrectDataType()
|
||||
|
||||
@ -385,12 +390,7 @@ class Item(BaseModel):
|
||||
if set(update_action["Value"].keys()) == set(["N"]):
|
||||
existing = self.attrs.get(attribute_name, DynamoType({"N": "0"}))
|
||||
self.attrs[attribute_name] = DynamoType(
|
||||
{
|
||||
"N": str(
|
||||
decimal.Decimal(existing.value)
|
||||
+ decimal.Decimal(new_value)
|
||||
)
|
||||
}
|
||||
{"N": str(Decimal(existing.value) + Decimal(new_value))}
|
||||
)
|
||||
elif set(update_action["Value"].keys()) == set(["SS"]):
|
||||
existing = self.attrs.get(attribute_name, DynamoType({"SS": {}}))
|
||||
|
@ -1,3 +1,5 @@
|
||||
from decimal import Decimal
|
||||
|
||||
import boto3
|
||||
import pytest
|
||||
|
||||
@ -40,3 +42,50 @@ def test_update_different_map_elements_in_single_request(table_name=None):
|
||||
ExpressionAttributeValues={":MyCount": 5},
|
||||
)
|
||||
assert table.get_item(Key={"pk": "example_id"})["Item"]["MyTotalCount"] == 5
|
||||
|
||||
|
||||
@pytest.mark.aws_verified
|
||||
@dynamodb_aws_verified()
|
||||
def test_update_item_add_float(table_name=None):
|
||||
table = boto3.resource("dynamodb", "us-east-1").Table(table_name)
|
||||
|
||||
# DECIMAL - DECIMAL
|
||||
table.put_item(Item={"pk": "foo", "amount": Decimal(100), "nr": 5})
|
||||
table.update_item(
|
||||
Key={"pk": "foo"},
|
||||
UpdateExpression="ADD amount :delta",
|
||||
ExpressionAttributeValues={":delta": -Decimal("88.3")},
|
||||
)
|
||||
assert table.scan()["Items"][0]["amount"] == Decimal("11.7")
|
||||
|
||||
# DECIMAL + DECIMAL
|
||||
table.update_item(
|
||||
Key={"pk": "foo"},
|
||||
UpdateExpression="ADD amount :delta",
|
||||
ExpressionAttributeValues={":delta": Decimal("25.41")},
|
||||
)
|
||||
assert table.scan()["Items"][0]["amount"] == Decimal("37.11")
|
||||
|
||||
# DECIMAL + INT
|
||||
table.update_item(
|
||||
Key={"pk": "foo"},
|
||||
UpdateExpression="ADD amount :delta",
|
||||
ExpressionAttributeValues={":delta": 6},
|
||||
)
|
||||
assert table.scan()["Items"][0]["amount"] == Decimal("43.11")
|
||||
|
||||
# INT + INT
|
||||
table.update_item(
|
||||
Key={"pk": "foo"},
|
||||
UpdateExpression="ADD nr :delta",
|
||||
ExpressionAttributeValues={":delta": 1},
|
||||
)
|
||||
assert table.scan()["Items"][0]["nr"] == Decimal("6")
|
||||
|
||||
# INT + DECIMAL
|
||||
table.update_item(
|
||||
Key={"pk": "foo"},
|
||||
UpdateExpression="ADD nr :delta",
|
||||
ExpressionAttributeValues={":delta": Decimal("25.41")},
|
||||
)
|
||||
assert table.scan()["Items"][0]["nr"] == Decimal("31.41")
|
||||
|
Loading…
Reference in New Issue
Block a user