DynamoDB: Fix calculation when adding/subtracting decimals (#7365)

This commit is contained in:
Bert Blommers 2024-02-19 20:54:02 +00:00 committed by GitHub
parent 7f6c9cb1de
commit 7c87bddeae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 59 additions and 10 deletions

View File

@ -1,6 +1,6 @@
import base64
import copy
import decimal
from decimal import Decimal
from typing import Any, Dict, List, Optional, Union
from boto3.dynamodb.types import TypeDeserializer, TypeSerializer
@ -100,9 +100,14 @@ class DynamoType(object):
if self.type != other.type:
raise TypeError("Different types of operandi is not allowed.")
if self.is_number():
self_value = float(self.value) if "." in self.value else int(self.value)
other_value = float(other.value) if "." in other.value else int(other.value)
return DynamoType({DDBType.NUMBER: f"{self_value + other_value}"})
self_value: Union[Decimal, int] = (
Decimal(self.value) if "." in self.value else int(self.value)
)
other_value: Union[Decimal, int] = (
Decimal(other.value) if "." in other.value else int(other.value)
)
total = self_value + other_value
return DynamoType({DDBType.NUMBER: f"{total}"})
else:
raise IncorrectDataType()
@ -385,12 +390,7 @@ class Item(BaseModel):
if set(update_action["Value"].keys()) == set(["N"]):
existing = self.attrs.get(attribute_name, DynamoType({"N": "0"}))
self.attrs[attribute_name] = DynamoType(
{
"N": str(
decimal.Decimal(existing.value)
+ decimal.Decimal(new_value)
)
}
{"N": str(Decimal(existing.value) + Decimal(new_value))}
)
elif set(update_action["Value"].keys()) == set(["SS"]):
existing = self.attrs.get(attribute_name, DynamoType({"SS": {}}))

View File

@ -1,3 +1,5 @@
from decimal import Decimal
import boto3
import pytest
@ -40,3 +42,50 @@ def test_update_different_map_elements_in_single_request(table_name=None):
ExpressionAttributeValues={":MyCount": 5},
)
assert table.get_item(Key={"pk": "example_id"})["Item"]["MyTotalCount"] == 5
@pytest.mark.aws_verified
@dynamodb_aws_verified()
def test_update_item_add_float(table_name=None):
table = boto3.resource("dynamodb", "us-east-1").Table(table_name)
# DECIMAL - DECIMAL
table.put_item(Item={"pk": "foo", "amount": Decimal(100), "nr": 5})
table.update_item(
Key={"pk": "foo"},
UpdateExpression="ADD amount :delta",
ExpressionAttributeValues={":delta": -Decimal("88.3")},
)
assert table.scan()["Items"][0]["amount"] == Decimal("11.7")
# DECIMAL + DECIMAL
table.update_item(
Key={"pk": "foo"},
UpdateExpression="ADD amount :delta",
ExpressionAttributeValues={":delta": Decimal("25.41")},
)
assert table.scan()["Items"][0]["amount"] == Decimal("37.11")
# DECIMAL + INT
table.update_item(
Key={"pk": "foo"},
UpdateExpression="ADD amount :delta",
ExpressionAttributeValues={":delta": 6},
)
assert table.scan()["Items"][0]["amount"] == Decimal("43.11")
# INT + INT
table.update_item(
Key={"pk": "foo"},
UpdateExpression="ADD nr :delta",
ExpressionAttributeValues={":delta": 1},
)
assert table.scan()["Items"][0]["nr"] == Decimal("6")
# INT + DECIMAL
table.update_item(
Key={"pk": "foo"},
UpdateExpression="ADD nr :delta",
ExpressionAttributeValues={":delta": Decimal("25.41")},
)
assert table.scan()["Items"][0]["nr"] == Decimal("31.41")